diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..bc34d1d00a09644e3a67e48e4c78fdc9ba58d104 Binary files /dev/null and b/.DS_Store differ diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..26d33521af10bcc7fd8cea344038eaaeb78d0ef5 --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,3 @@ +# Default ignored files +/shelf/ +/workspace.xml diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000000000000000000000000000000000000..2207299f02f75c9370600b47ce77051ca3b54130 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,29 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000000000000000000000000000000000000..105ce2da2d6447d11dfe32bfb846c3d5b199fc99 --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/materials.mhg-ged.iml b/.idea/materials.mhg-ged.iml new file mode 100644 index 0000000000000000000000000000000000000000..039314de6c082718b0c4495bd64f8df7279e477f --- /dev/null +++ b/.idea/materials.mhg-ged.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000000000000000000000000000000000000..6c89280f83fddd489fc6c07c979f7f1ba7c4969f --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000000000000000000000000000000000000..35eb1ddfbbc029bcab630581847471d7f238ec53 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/README.md b/README.md index 
7b95401dc46245ac339fc25059d4a56d90b4cde5..9709c8cca394a4de06e80fb7ea7a711b8d3e6749 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,78 @@ ---- -license: apache-2.0 ---- +--- +license: apache-2.0 +--- +# mhg-gnn + +This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network" + +**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374) + +![mhg-gnn](images/mhg_example1.png) + +## Introduction + +We present MHG-GNN, an autoencoder architecture +that has an encoder based on GNN and a decoder based on a sequential model with MHG. +Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and +demonstrate high predictive performance on molecular graph data. +In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output. + +## Table of Contents + +1. [Getting Started](#getting-started) + 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs) + 2. [Installation](#installation) +2. [Feature Extraction](#feature-extraction) + +## Getting Started + +**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.** + +### Pretrained Models and Training Logs + +We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]() + +Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs. + +### Installation + +We recommend to create a virtual environment. For example: + +``` +python3 -m venv .venv +. .venv/bin/activate +``` + +Type the following command once the virtual environment is activated: + +``` +git clone git@github.ibm.com:CMD-TRL/mhg-gnn.git +cd ./mhg-gnn +pip install . 
+``` + +## Feature Extraction + +The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks. + +To load mhg-gnn, you can simply use: + +```python +import torch +import load + +model = load.load() +``` + +To encode SMILES into embeddings, you can use: + +```python +with torch.no_grad(): + repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"]) +``` + +For decoder, you can use the function, so you can return from embeddings to SMILES strings: + +```python +orig = model.decode(repr) +``` \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ed12f0d3c8ed176384f16d26197882c9ad47a36c --- /dev/null +++ b/__init__.py @@ -0,0 +1,5 @@ +# -*- coding:utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# \ No newline at end of file diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82cf6ca01ddef5690bc9ced78d22113c2a6a7528 Binary files /dev/null and b/__pycache__/__init__.cpython-310.pyc differ diff --git a/__pycache__/load.cpython-310.pyc b/__pycache__/load.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a31d1e584591bc095c35e8e539acd4f23a1c4e Binary files /dev/null and b/__pycache__/load.cpython-310.pyc differ diff --git a/graph_grammar/.DS_Store b/graph_grammar/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..c5f5f83e31db413d50194276d0d7497e37b1f13a Binary files /dev/null and b/graph_grammar/.DS_Store differ diff --git a/graph_grammar/__init__.py b/graph_grammar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..26f82acdc3d0383157745e5f0fe8ddd870325145 --- /dev/null +++ 
b/graph_grammar/__init__.py @@ -0,0 +1,19 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 2018" +__version__ = "0.1" +__date__ = "Jan 1 2018" + diff --git a/graph_grammar/__pycache__/__init__.cpython-310.pyc b/graph_grammar/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ced0ce9fa98dbac640cd04dcdbeb90278633020 Binary files /dev/null and b/graph_grammar/__pycache__/__init__.cpython-310.pyc differ diff --git a/graph_grammar/__pycache__/hypergraph.cpython-310.pyc b/graph_grammar/__pycache__/hypergraph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da6dafdb54e559d55d716ecb81242709a112a1b8 Binary files /dev/null and b/graph_grammar/__pycache__/hypergraph.cpython-310.pyc differ diff --git a/graph_grammar/algo/__init__.py b/graph_grammar/algo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e6f597c0a768b8dd64708ec70fe7071000953ce8 --- /dev/null +++ b/graph_grammar/algo/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. 
+""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 2018" +__version__ = "0.1" +__date__ = "Jan 1 2018" + diff --git a/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc b/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18bd58db664f9fcabf573dcdc21fd6e7ef497e89 Binary files /dev/null and b/graph_grammar/algo/__pycache__/__init__.cpython-310.pyc differ diff --git a/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc b/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d050b46a7e787f475c853001d7510d3c8b1d4031 Binary files /dev/null and b/graph_grammar/algo/__pycache__/tree_decomposition.cpython-310.pyc differ diff --git a/graph_grammar/algo/tree_decomposition.py b/graph_grammar/algo/tree_decomposition.py new file mode 100644 index 0000000000000000000000000000000000000000..81cb7748f573c99597c8f0658555b9efd1171cfa --- /dev/null +++ b/graph_grammar/algo/tree_decomposition.py @@ -0,0 +1,821 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 2017" +__version__ = "0.1" +__date__ = "Dec 11 2017" + +from copy import deepcopy +from itertools import combinations +from ..hypergraph import Hypergraph +import networkx as nx +import numpy as np + + +class CliqueTree(nx.Graph): + ''' clique tree object + + Attributes + ---------- + hg : Hypergraph + This hypergraph will be decomposed. 
+ root_hg : Hypergraph + Hypergraph on the root node. + ident_node_dict : dict + ident_node_dict[key_node] gives a list of nodes that are identical (i.e., the adjacent hyperedges are common) + ''' + def __init__(self, hg=None, **kwargs): + self.hg = deepcopy(hg) + if self.hg is not None: + self.ident_node_dict = self.hg.get_identical_node_dict() + else: + self.ident_node_dict = {} + super().__init__(**kwargs) + + @property + def root_hg(self): + ''' return the hypergraph on the root node + ''' + return self.nodes[0]['subhg'] + + @root_hg.setter + def root_hg(self, hypergraph): + ''' set the hypergraph on the root node + ''' + self.nodes[0]['subhg'] = hypergraph + + def insert_subhg(self, subhypergraph: Hypergraph) -> None: + ''' insert a subhypergraph, which is extracted from a root hypergraph, into the tree. + + Parameters + ---------- + subhg : Hypergraph + ''' + num_nodes = self.number_of_nodes() + self.add_node(num_nodes, subhg=subhypergraph) + self.add_edge(num_nodes, 0) + adj_nodes = deepcopy(list(self.adj[0].keys())) + for each_node in adj_nodes: + if len(self.nodes[each_node]["subhg"].nodes.intersection( + self.nodes[num_nodes]["subhg"].nodes)\ + - self.root_hg.nodes) != 0 and each_node != num_nodes: + self.remove_edge(0, each_node) + self.add_edge(each_node, num_nodes) + + def to_irredundant(self) -> None: + ''' convert the clique tree to be irredundant + ''' + for each_node in self.hg.nodes: + subtree = self.subgraph([ + each_tree_node for each_tree_node in self.nodes()\ + if each_node in self.nodes[each_tree_node]["subhg"].nodes]).copy() + leaf_node_list = [x for x in subtree.nodes() if subtree.degree(x)==1] + redundant_leaf_node_list = [] + for each_leaf_node in leaf_node_list: + if len(self.nodes[each_leaf_node]["subhg"].adj_edges(each_node)) == 0: + redundant_leaf_node_list.append(each_leaf_node) + for each_red_leaf_node in redundant_leaf_node_list: + current_node = each_red_leaf_node + while subtree.degree(current_node) == 1 \ + and 
len(subtree.nodes[current_node]["subhg"].adj_edges(each_node)) == 0: + self.nodes[current_node]["subhg"].remove_node(each_node) + remove_node = current_node + current_node = list(dict(subtree[remove_node]).keys())[0] + subtree.remove_node(remove_node) + + fixed_node_set = deepcopy(self.nodes) + for each_node in fixed_node_set: + if self.nodes[each_node]["subhg"].num_edges == 0: + if len(self[each_node]) == 1: + self.remove_node(each_node) + elif len(self[each_node]) == 2: + self.add_edge(*self[each_node]) + self.remove_node(each_node) + else: + pass + else: + pass + + redundant = True + while redundant: + redundant = False + fixed_edge_set = deepcopy(self.edges) + remove_node_set = set() + for node_1, node_2 in fixed_edge_set: + if node_1 in remove_node_set or node_2 in remove_node_set: + pass + else: + if self.nodes[node_1]['subhg'].is_subhg(self.nodes[node_2]['subhg']): + redundant = True + adj_node_list = set(self.adj[node_1]) - {node_2} + self.remove_node(node_1) + remove_node_set.add(node_1) + for each_node in adj_node_list: + self.add_edge(node_2, each_node) + + elif self.nodes[node_2]['subhg'].is_subhg(self.nodes[node_1]['subhg']): + redundant = True + adj_node_list = set(self.adj[node_2]) - {node_1} + self.remove_node(node_2) + remove_node_set.add(node_2) + for each_node in adj_node_list: + self.add_edge(node_1, each_node) + + def node_update(self, key_node: str, subhg) -> None: + """ given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH. + + Parameters + ---------- + key_node : str + key node that must be removed. 
+ subhg : Hypegraph + """ + for each_edge in subhg.edges: + self.root_hg.remove_edge(each_edge) + self.root_hg.remove_nodes(self.ident_node_dict[key_node]) + + adj_node_list = list(subhg.nodes) + for each_node in subhg.nodes: + if each_node not in self.ident_node_dict[key_node]: + if set(self.root_hg.adj_edges(each_node)).issubset(subhg.edges): + self.root_hg.remove_node(each_node) + adj_node_list.remove(each_node) + else: + adj_node_list.remove(each_node) + + for each_node_1, each_node_2 in combinations(adj_node_list, 2): + if not self.root_hg.is_adj(each_node_1, each_node_2): + self.root_hg.add_edge(set([each_node_1, each_node_2]), attr_dict=dict(tmp=True)) + + subhg.remove_edges_with_attr({'tmp' : True}) + self.insert_subhg(subhg) + + def update(self, subhg, remove_nodes=False): + """ given a pair of a hypergraph, H, and its subhypergraph, sH, return a hypergraph H\sH. + + Parameters + ---------- + subhg : Hypegraph + """ + for each_edge in subhg.edges: + self.root_hg.remove_edge(each_edge) + if remove_nodes: + remove_edge_list = [] + for each_edge in self.root_hg.edges: + if set(self.root_hg.nodes_in_edge(each_edge)).issubset(subhg.nodes)\ + and self.root_hg.edge_attr(each_edge).get('tmp', False): + remove_edge_list.append(each_edge) + self.root_hg.remove_edges(remove_edge_list) + + adj_node_list = list(subhg.nodes) + for each_node in subhg.nodes: + if self.root_hg.degree(each_node) == 0: + self.root_hg.remove_node(each_node) + adj_node_list.remove(each_node) + + if len(adj_node_list) != 1 and not remove_nodes: + self.root_hg.add_edge(set(adj_node_list), attr_dict=dict(tmp=True)) + ''' + else: + for each_node_1, each_node_2 in combinations(adj_node_list, 2): + if not self.root_hg.is_adj(each_node_1, each_node_2): + self.root_hg.add_edge( + [each_node_1, each_node_2], attr_dict=dict(tmp=True)) + ''' + subhg.remove_edges_with_attr({'tmp':True}) + self.insert_subhg(subhg) + + +def _get_min_deg_node(hg, ident_node_dict: dict, mode='mol'): + if mode == 'standard': + 
degree_dict = hg.degrees() + min_deg_node = min(degree_dict, key=degree_dict.get) + min_deg_subhg = hg.adj_subhg(min_deg_node, ident_node_dict) + return min_deg_node, min_deg_subhg + elif mode == 'mol': + degree_dict = hg.degrees() + min_deg = min(degree_dict.values()) + min_deg_node_list = [each_node for each_node in hg.nodes if degree_dict[each_node]==min_deg] + min_deg_subhg_list = [hg.adj_subhg(each_min_deg_node, ident_node_dict) + for each_min_deg_node in min_deg_node_list] + best_score = np.inf + best_idx = -1 + for each_idx in range(len(min_deg_subhg_list)): + if min_deg_subhg_list[each_idx].num_nodes < best_score: + best_idx = each_idx + return min_deg_node_list[each_idx], min_deg_subhg_list[each_idx] + else: + raise ValueError + + +def tree_decomposition(hg, irredundant=True): + """ compute a tree decomposition of the input hypergraph + + Parameters + ---------- + hg : Hypergraph + hypergraph to be decomposed + irredundant : bool + if True, irredundant tree decomposition will be computed. + + Returns + ------- + clique_tree : nx.Graph + each node contains a subhypergraph of `hg` + """ + org_hg = hg.copy() + ident_node_dict = hg.get_identical_node_dict() + clique_tree = CliqueTree(org_hg) + clique_tree.add_node(0, subhg=org_hg) + while True: + degree_dict = org_hg.degrees() + min_deg_node = min(degree_dict, key=degree_dict.get) + min_deg_subhg = org_hg.adj_subhg(min_deg_node, ident_node_dict) + if org_hg.nodes == min_deg_subhg.nodes: + break + + # org_hg and min_deg_subhg are divided + clique_tree.node_update(min_deg_node, min_deg_subhg) + + clique_tree.root_hg.remove_edges_with_attr({'tmp' : True}) + + if irredundant: + clique_tree.to_irredundant() + return clique_tree + + +def tree_decomposition_with_hrg(hg, hrg, irredundant=True, return_root=False): + ''' compute a tree decomposition given a hyperedge replacement grammar. + the resultant clique tree should induce a less compact HRG. 
+ + Parameters + ---------- + hg : Hypergraph + hypergraph to be decomposed + hrg : HyperedgeReplacementGrammar + current HRG + irredundant : bool + if True, irredundant tree decomposition will be computed. + + Returns + ------- + clique_tree : nx.Graph + each node contains a subhypergraph of `hg` + ''' + org_hg = hg.copy() + ident_node_dict = hg.get_identical_node_dict() + clique_tree = CliqueTree(org_hg) + clique_tree.add_node(0, subhg=org_hg) + root_node = 0 + + # construct a clique tree using HRG + success_any = True + while success_any: + success_any = False + for each_prod_rule in hrg.prod_rule_list: + org_hg, success, subhg = each_prod_rule.revert(org_hg, True) + if success: + if each_prod_rule.is_start_rule: root_node = clique_tree.number_of_nodes() + success_any = True + subhg.remove_edges_with_attr({'terminal' : False}) + clique_tree.root_hg = org_hg + clique_tree.insert_subhg(subhg) + + clique_tree.root_hg = org_hg + + for each_edge in deepcopy(org_hg.edges): + if not org_hg.edge_attr(each_edge)['terminal']: + node_list = org_hg.nodes_in_edge(each_edge) + org_hg.remove_edge(each_edge) + + for each_node_1, each_node_2 in combinations(node_list, 2): + if not org_hg.is_adj(each_node_1, each_node_2): + org_hg.add_edge([each_node_1, each_node_2], attr_dict=dict(tmp=True)) + + # construct a clique tree using the existing algorithm + degree_dict = org_hg.degrees() + if degree_dict: + while True: + min_deg_node, min_deg_subhg = _get_min_deg_node(org_hg, ident_node_dict) + if org_hg.nodes == min_deg_subhg.nodes: break + + # org_hg and min_deg_subhg are divided + clique_tree.node_update(min_deg_node, min_deg_subhg) + + clique_tree.root_hg.remove_edges_with_attr({'tmp' : True}) + if irredundant: + clique_tree.to_irredundant() + + if return_root: + if root_node == 0 and 0 not in clique_tree.nodes: + root_node = clique_tree.number_of_nodes() + while root_node not in clique_tree.nodes: + root_node -= 1 + elif root_node not in clique_tree.nodes: + while root_node not 
in clique_tree.nodes: + root_node -= 1 + else: + pass + return clique_tree, root_node + else: + return clique_tree + + +def tree_decomposition_from_leaf(hg, irredundant=True): + """ compute a tree decomposition of the input hypergraph + + Parameters + ---------- + hg : Hypergraph + hypergraph to be decomposed + irredundant : bool + if True, irredundant tree decomposition will be computed. + + Returns + ------- + clique_tree : nx.Graph + each node contains a subhypergraph of `hg` + """ + def apply_normal_decomposition(clique_tree): + degree_dict = clique_tree.root_hg.degrees() + min_deg_node = min(degree_dict, key=degree_dict.get) + min_deg_subhg = clique_tree.root_hg.adj_subhg(min_deg_node, clique_tree.ident_node_dict) + if clique_tree.root_hg.nodes == min_deg_subhg.nodes: + return clique_tree, False + clique_tree.node_update(min_deg_node, min_deg_subhg) + return clique_tree, True + + def apply_min_edge_deg_decomposition(clique_tree): + edge_degree_dict = clique_tree.root_hg.edge_degrees() + non_tmp_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \ + if not clique_tree.root_hg.edge_attr(each_edge).get('tmp')] + if not non_tmp_edge_list: + return clique_tree, False + min_deg_edge = None + min_deg = np.inf + for each_edge in non_tmp_edge_list: + if min_deg > edge_degree_dict[each_edge]: + min_deg_edge = each_edge + min_deg = edge_degree_dict[each_edge] + node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge) + min_deg_subhg = clique_tree.root_hg.get_subhg( + node_list, [min_deg_edge], clique_tree.ident_node_dict) + if clique_tree.root_hg.nodes == min_deg_subhg.nodes: + return clique_tree, False + clique_tree.update(min_deg_subhg) + return clique_tree, True + + org_hg = hg.copy() + clique_tree = CliqueTree(org_hg) + clique_tree.add_node(0, subhg=org_hg) + + success = True + while success: + clique_tree, success = apply_min_edge_deg_decomposition(clique_tree) + if not success: + clique_tree, success = apply_normal_decomposition(clique_tree) + + 
clique_tree.root_hg.remove_edges_with_attr({'tmp' : True}) + if irredundant: + clique_tree.to_irredundant() + return clique_tree + +def topological_tree_decomposition( + hg, irredundant=True, rip_labels=True, shrink_cycle=False, contract_cycles=False): + ''' compute a tree decomposition of the input hypergraph + + Parameters + ---------- + hg : Hypergraph + hypergraph to be decomposed + irredundant : bool + if True, irredundant tree decomposition will be computed. + + Returns + ------- + clique_tree : CliqueTree + each node contains a subhypergraph of `hg` + ''' + def _contract_tree(clique_tree): + ''' contract a single leaf + + Parameters + ---------- + clique_tree : CliqueTree + + Returns + ------- + CliqueTree, bool + bool represents whether this operation succeeds or not. + ''' + edge_degree_dict = clique_tree.root_hg.edge_degrees() + leaf_edge_list = [each_edge for each_edge in clique_tree.root_hg.edges \ + if (not clique_tree.root_hg.edge_attr(each_edge).get('tmp'))\ + and edge_degree_dict[each_edge] == 1] + if not leaf_edge_list: + return clique_tree, False + min_deg_edge = leaf_edge_list[0] + node_list = clique_tree.root_hg.nodes_in_edge(min_deg_edge) + min_deg_subhg = clique_tree.root_hg.get_subhg( + node_list, [min_deg_edge], clique_tree.ident_node_dict) + if clique_tree.root_hg.nodes == min_deg_subhg.nodes: + return clique_tree, False + clique_tree.update(min_deg_subhg) + return clique_tree, True + + def _rip_labels_from_cycles(clique_tree, org_hg): + ''' rip hyperedge-labels off + + Parameters + ---------- + clique_tree : CliqueTree + org_hg : Hypergraph + + Returns + ------- + CliqueTree, bool + bool represents whether this operation succeeds or not. 
+ ''' + ident_node_dict = clique_tree.ident_node_dict #hg.get_identical_node_dict() + for each_edge in clique_tree.root_hg.edges: + if each_edge in org_hg.edges: + if org_hg.in_cycle(each_edge): + node_list = clique_tree.root_hg.nodes_in_edge(each_edge) + subhg = clique_tree.root_hg.get_subhg( + node_list, [each_edge], ident_node_dict) + if clique_tree.root_hg.nodes == subhg.nodes: + return clique_tree, False + clique_tree.update(subhg) + ''' + in_cycle_dict = {each_node: org_hg.node_attr(each_node)['is_in_ring'] for each_node in node_list} + if not all(in_cycle_dict.values()): + node_not_in_cycle = [each_node for each_node in in_cycle_dict.keys() if not in_cycle_dict[each_node]][0] + node_list = [node_not_in_cycle] + node_list.extend(clique_tree.root_hg.adj_nodes(node_not_in_cycle)) + edge_list = clique_tree.root_hg.adj_edges(node_not_in_cycle) + import pdb; pdb.set_trace() + subhg = clique_tree.root_hg.get_subhg( + node_list, edge_list, ident_node_dict) + + clique_tree.update(subhg) + ''' + return clique_tree, True + return clique_tree, False + + def _shrink_cycle(clique_tree): + ''' shrink a cycle + + Parameters + ---------- + clique_tree : CliqueTree + + Returns + ------- + CliqueTree, bool + bool represents whether this operation succeeds or not. 
+ ''' + def filter_subhg(subhg, hg, key_node): + num_nodes_cycle = 0 + nodes_in_cycle_list = [] + for each_node in subhg.nodes: + if hg.in_cycle(each_node): + num_nodes_cycle += 1 + if each_node != key_node: + nodes_in_cycle_list.append(each_node) + if num_nodes_cycle > 3: + break + if num_nodes_cycle != 3: + return False + else: + for each_edge in hg.edges: + if set(nodes_in_cycle_list).issubset(hg.nodes_in_edge(each_edge)): + return False + return True + + #ident_node_dict = hg.get_identical_node_dict() + ident_node_dict = clique_tree.ident_node_dict + for each_node in clique_tree.root_hg.nodes: + if clique_tree.root_hg.in_cycle(each_node)\ + and filter_subhg(clique_tree.root_hg.adj_subhg(each_node, ident_node_dict), + clique_tree.root_hg, + each_node): + target_node = each_node + target_subhg = clique_tree.root_hg.adj_subhg(target_node, ident_node_dict) + if clique_tree.root_hg.nodes == target_subhg.nodes: + return clique_tree, False + clique_tree.update(target_subhg) + return clique_tree, True + return clique_tree, False + + def _contract_cycles(clique_tree): + ''' + remove a subhypergraph that looks like a cycle on a leaf. + + Parameters + ---------- + clique_tree : CliqueTree + + Returns + ------- + CliqueTree, bool + bool represents whether this operation succeeds or not. + ''' + def _divide_hg(hg): + ''' divide a hypergraph into subhypergraphs such that + each subhypergraph is connected to each other in a tree-like way. 
+ + Parameters + ---------- + hg : Hypergraph + + Returns + ------- + list of Hypergraphs + each element corresponds to a subhypergraph of `hg` + ''' + for each_node in hg.nodes: + if hg.is_dividable(each_node): + adj_edges_dict = {each_edge: hg.in_cycle(each_edge) for each_edge in hg.adj_edges(each_node)} + ''' + if any(adj_edges_dict.values()): + import pdb; pdb.set_trace() + edge_in_cycle = [each_key for each_key, each_val in adj_edges_dict.items() if each_val][0] + subhg1, subhg2, subhg3 = hg.divide(each_node, edge_in_cycle) + return _divide_hg(subhg1) + _divide_hg(subhg2) + _divide_hg(subhg3) + else: + ''' + subhg1, subhg2 = hg.divide(each_node) + return _divide_hg(subhg1) + _divide_hg(subhg2) + return [hg] + + def _is_leaf(hg, divided_subhg) -> bool: + ''' judge whether subhg is a leaf-like in the original hypergraph + + Parameters + ---------- + hg : Hypergraph + divided_subhg : Hypergraph + `divided_subhg` is a subhypergraph of `hg` + + Returns + ------- + bool + ''' + ''' + adj_edges_set = set([]) + for each_node in divided_subhg.nodes: + adj_edges_set.update(set(hg.adj_edges(each_node))) + + + _hg = deepcopy(hg) + _hg.remove_subhg(divided_subhg) + if nx.is_connected(_hg.hg) != (len(adj_edges_set - divided_subhg.edges) == 1): + import pdb; pdb.set_trace() + return len(adj_edges_set - divided_subhg.edges) == 1 + ''' + _hg = deepcopy(hg) + _hg.remove_subhg(divided_subhg) + return nx.is_connected(_hg.hg) + + subhg_list = _divide_hg(clique_tree.root_hg) + if len(subhg_list) == 1: + return clique_tree, False + else: + while len(subhg_list) > 1: + max_leaf_subhg = None + for each_subhg in subhg_list: + if _is_leaf(clique_tree.root_hg, each_subhg): + if max_leaf_subhg is None: + max_leaf_subhg = each_subhg + elif max_leaf_subhg.num_nodes < each_subhg.num_nodes: + max_leaf_subhg = each_subhg + clique_tree.update(max_leaf_subhg) + subhg_list.remove(max_leaf_subhg) + return clique_tree, True + + org_hg = hg.copy() + clique_tree = CliqueTree(org_hg) + 
clique_tree.add_node(0, subhg=org_hg) + + success = True + while success: + ''' + clique_tree, success = _rip_labels_from_cycles(clique_tree, hg) + if not success: + clique_tree, success = _contract_cycles(clique_tree) + ''' + clique_tree, success = _contract_tree(clique_tree) + if not success: + if rip_labels: + clique_tree, success = _rip_labels_from_cycles(clique_tree, hg) + if not success: + if shrink_cycle: + clique_tree, success = _shrink_cycle(clique_tree) + if not success: + if contract_cycles: + clique_tree, success = _contract_cycles(clique_tree) + clique_tree.root_hg.remove_edges_with_attr({'tmp' : True}) + if irredundant: + clique_tree.to_irredundant() + return clique_tree + +def molecular_tree_decomposition(hg, irredundant=True): + """ compute a tree decomposition of the input molecular hypergraph + + Parameters + ---------- + hg : Hypergraph + molecular hypergraph to be decomposed + irredundant : bool + if True, irredundant tree decomposition will be computed. + + Returns + ------- + clique_tree : CliqueTree + each node contains a subhypergraph of `hg` + """ + def _divide_hg(hg): + ''' divide a hypergraph into subhypergraphs such that + each subhypergraph is connected to each other in a tree-like way. 
+ + Parameters + ---------- + hg : Hypergraph + + Returns + ------- + list of Hypergraphs + each element corresponds to a subhypergraph of `hg` + ''' + is_ring = False + for each_node in hg.nodes: + if hg.node_attr(each_node)['is_in_ring']: + is_ring = True + if not hg.node_attr(each_node)['is_in_ring'] \ + and hg.degree(each_node) == 2: + subhg1, subhg2 = hg.divide(each_node) + return _divide_hg(subhg1) + _divide_hg(subhg2) + + if is_ring: + subhg_list = [] + remove_edge_list = [] + remove_node_list = [] + for each_edge in hg.edges: + node_list = hg.nodes_in_edge(each_edge) + subhg = hg.get_subhg(node_list, [each_edge], hg.get_identical_node_dict()) + subhg_list.append(subhg) + remove_edge_list.append(each_edge) + for each_node in node_list: + if not hg.node_attr(each_node)['is_in_ring']: + remove_node_list.append(each_node) + hg.remove_edges(remove_edge_list) + hg.remove_nodes(remove_node_list, False) + return subhg_list + [hg] + else: + return [hg] + + org_hg = hg.copy() + clique_tree = CliqueTree(org_hg) + clique_tree.add_node(0, subhg=org_hg) + + subhg_list = _divide_hg(deepcopy(clique_tree.root_hg)) + #_subhg_list = deepcopy(subhg_list) + if len(subhg_list) == 1: + pass + else: + while len(subhg_list) > 1: + max_leaf_subhg = None + for each_subhg in subhg_list: + if _is_leaf(clique_tree.root_hg, each_subhg) and not _is_ring(each_subhg): + if max_leaf_subhg is None: + max_leaf_subhg = each_subhg + elif max_leaf_subhg.num_nodes < each_subhg.num_nodes: + max_leaf_subhg = each_subhg + + if max_leaf_subhg is None: + for each_subhg in subhg_list: + if _is_ring_label(clique_tree.root_hg, each_subhg): + if max_leaf_subhg is None: + max_leaf_subhg = each_subhg + elif max_leaf_subhg.num_nodes < each_subhg.num_nodes: + max_leaf_subhg = each_subhg + if max_leaf_subhg is not None: + clique_tree.update(max_leaf_subhg) + subhg_list.remove(max_leaf_subhg) + else: + for each_subhg in subhg_list: + if _is_leaf(clique_tree.root_hg, each_subhg): + if max_leaf_subhg is None: + 
max_leaf_subhg = each_subhg + elif max_leaf_subhg.num_nodes < each_subhg.num_nodes: + max_leaf_subhg = each_subhg + if max_leaf_subhg is not None: + clique_tree.update(max_leaf_subhg, True) + subhg_list.remove(max_leaf_subhg) + else: + break + if len(subhg_list) > 1: + ''' + for each_idx, each_subhg in enumerate(subhg_list): + each_subhg.draw(f'{each_idx}', True) + clique_tree.root_hg.draw('root', True) + import pickle + with open('buggy_hg.pkl', 'wb') as f: + pickle.dump(hg, f) + return clique_tree, subhg_list, _subhg_list + ''' + raise RuntimeError('bug in tree decomposition algorithm') + clique_tree.root_hg.remove_edges_with_attr({'tmp' : True}) + + ''' + for each_tree_node in clique_tree.adj[0]: + subhg = clique_tree.nodes[each_tree_node]['subhg'] + for each_edge in subhg.edges: + if set(subhg.nodes_in_edge(each_edge)).issubset(clique_tree.root_hg.nodes): + clique_tree.root_hg.add_edge(set(subhg.nodes_in_edge(each_edge)), attr_dict=dict(tmp=True)) + ''' + if irredundant: + clique_tree.to_irredundant() + return clique_tree #, _subhg_list + +def _is_leaf(hg, subhg) -> bool: + ''' judge whether subhg is a leaf-like in the original hypergraph + + Parameters + ---------- + hg : Hypergraph + subhg : Hypergraph + `subhg` is a subhypergraph of `hg` + + Returns + ------- + bool + ''' + if len(subhg.edges) == 0: + adj_edge_set = set([]) + subhg_edge_set = set([]) + for each_edge in hg.edges: + if set(hg.nodes_in_edge(each_edge)).issubset(subhg.nodes) and hg.edge_attr(each_edge).get('tmp', False): + subhg_edge_set.add(each_edge) + for each_node in subhg.nodes: + adj_edge_set.update(set(hg.adj_edges(each_node))) + if subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1: + return True + else: + return False + elif len(subhg.edges) == 1: + adj_edge_set = set([]) + subhg_edge_set = subhg.edges + for each_node in subhg.nodes: + for each_adj_edge in hg.adj_edges(each_node): + adj_edge_set.add(each_adj_edge) + if 
subhg_edge_set.issubset(adj_edge_set) and len(adj_edge_set.difference(subhg_edge_set)) == 1: + return True + else: + return False + else: + raise ValueError('subhg should be nodes only or one-edge hypergraph.') + +def _is_ring_label(hg, subhg): + if len(subhg.edges) != 1: + return False + edge_name = list(subhg.edges)[0] + #assert edge_name in hg.edges, f'{edge_name}' + is_in_ring = False + for each_node in subhg.nodes: + if subhg.node_attr(each_node)['is_in_ring']: + is_in_ring = True + else: + adj_edge_list = list(hg.adj_edges(each_node)) + adj_edge_list.remove(edge_name) + if len(adj_edge_list) == 1: + if not hg.edge_attr(adj_edge_list[0]).get('tmp', False): + return False + elif len(adj_edge_list) == 0: + pass + else: + raise ValueError + if is_in_ring: + return True + else: + return False + +def _is_ring(hg): + for each_node in hg.nodes: + if not hg.node_attr(each_node)['is_in_ring']: + return False + return True + diff --git a/graph_grammar/graph_grammar/__init__.py b/graph_grammar/graph_grammar/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85e6131daba8a4f601ae72d37e6eb035d9503045 --- /dev/null +++ b/graph_grammar/graph_grammar/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 
2018" +__version__ = "0.1" +__date__ = "Jan 1 2018" + diff --git a/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..214041c778104f480f6870e377626ee4c7021c7e Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/__init__.cpython-310.pyc differ diff --git a/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c09d5018842a72c2fa5452afbbc4e2d19fdda8a3 Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/base.cpython-310.pyc differ diff --git a/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..554af9638f063de67d4c34092aa0aacc2b6d7681 Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/corpus.cpython-310.pyc differ diff --git a/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14ca25470c52e7547eba204d4019f21c321558ba Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/hrg.cpython-310.pyc differ diff --git a/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..208465ba58169d7ac7b468ebde4ed3953cdec08b Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/symbols.cpython-310.pyc differ diff --git a/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc b/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f5fda6d465da19e2b79cbeaf39bde2d7ac1f4a77 Binary files /dev/null and b/graph_grammar/graph_grammar/__pycache__/utils.cpython-310.pyc differ diff --git a/graph_grammar/graph_grammar/base.py b/graph_grammar/graph_grammar/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c5977dff873dc004b0e1f1fbae1e65af5b52052c --- /dev/null +++ b/graph_grammar/graph_grammar/base.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 2017" +__version__ = "0.1" +__date__ = "Dec 11 2017" + +from abc import ABCMeta, abstractmethod + +class GraphGrammarBase(metaclass=ABCMeta): + @abstractmethod + def learn(self): + pass + + @abstractmethod + def sample(self): + pass diff --git a/graph_grammar/graph_grammar/corpus.py b/graph_grammar/graph_grammar/corpus.py new file mode 100644 index 0000000000000000000000000000000000000000..dad81a13d0b4873b2929e32a9031f3105811c808 --- /dev/null +++ b/graph_grammar/graph_grammar/corpus.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 
2018" +__version__ = "0.1" +__date__ = "Jun 4 2018" + +from collections import Counter +from functools import partial +from .utils import _easy_node_match, _edge_match, _node_match, common_node_list, _node_match_prod_rule +from networkx.algorithms.isomorphism import GraphMatcher +import os + + +class CliqueTreeCorpus(object): + + ''' clique tree corpus + + Attributes + ---------- + clique_tree_list : list of CliqueTree + subhg_list : list of Hypergraph + ''' + + def __init__(self): + self.clique_tree_list = [] + self.subhg_list = [] + + @property + def size(self): + return len(self.subhg_list) + + def add_clique_tree(self, clique_tree): + for each_node in clique_tree.nodes: + subhg = clique_tree.nodes[each_node]['subhg'] + subhg_idx = self.add_subhg(subhg) + clique_tree.nodes[each_node]['subhg_idx'] = subhg_idx + self.clique_tree_list.append(clique_tree) + + def add_to_subhg_list(self, clique_tree, root_node): + parent_node_dict = {} + current_node = None + parent_node_dict[root_node] = None + stack = [root_node] + while stack: + current_node = stack.pop() + current_subhg = clique_tree.nodes[current_node]['subhg'] + for each_child in clique_tree.adj[current_node]: + if each_child != parent_node_dict[current_node]: + stack.append(each_child) + parent_node_dict[each_child] = current_node + if parent_node_dict[current_node] is not None: + parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg'] + common, _ = common_node_list(parent_subhg, current_subhg) + parent_subhg.add_edge(set(common), attr_dict={'tmp': True}) + + parent_node_dict = {} + current_node = None + parent_node_dict[root_node] = None + stack = [root_node] + while stack: + current_node = stack.pop() + current_subhg = clique_tree.nodes[current_node]['subhg'] + for each_child in clique_tree.adj[current_node]: + if each_child != parent_node_dict[current_node]: + stack.append(each_child) + parent_node_dict[each_child] = current_node + if parent_node_dict[current_node] is not None: + 
parent_subhg = clique_tree.nodes[parent_node_dict[current_node]]['subhg'] + common, _ = common_node_list(parent_subhg, current_subhg) + for each_idx, each_node in enumerate(common): + current_subhg.set_node_attr(each_node, {'ext_id': each_idx}) + + subhg_idx, is_new = self.add_subhg(current_subhg) + clique_tree.nodes[current_node]['subhg_idx'] = subhg_idx + return clique_tree + + def add_subhg(self, subhg): + if len(self.subhg_list) == 0: + node_dict = {} + for each_node in subhg.nodes: + node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__() + node_list = [] + for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]): + node_list.append(each_key) + for each_idx, each_node in enumerate(node_list): + subhg.node_attr(each_node)['order4hrg'] = each_idx + self.subhg_list.append(subhg) + return 0, True + else: + match = False + subhg_bond_symbol_counter \ + = Counter([subhg.node_attr(each_node)['symbol'] \ + for each_node in subhg.nodes]) + subhg_atom_symbol_counter \ + = Counter([subhg.edge_attr(each_edge).get('symbol', None) \ + for each_edge in subhg.edges]) + for each_idx, each_subhg in enumerate(self.subhg_list): + each_bond_symbol_counter \ + = Counter([each_subhg.node_attr(each_node)['symbol'] \ + for each_node in each_subhg.nodes]) + each_atom_symbol_counter \ + = Counter([each_subhg.edge_attr(each_edge).get('symbol', None) \ + for each_edge in each_subhg.edges]) + if not match \ + and (subhg.num_nodes == each_subhg.num_nodes + and subhg.num_edges == each_subhg.num_edges + and subhg_bond_symbol_counter == each_bond_symbol_counter + and subhg_atom_symbol_counter == each_atom_symbol_counter): + gm = GraphMatcher(each_subhg.hg, + subhg.hg, + node_match=_easy_node_match, + edge_match=_edge_match) + try: + isomap = next(gm.isomorphisms_iter()) + match = True + for each_node in each_subhg.nodes: + subhg.node_attr(isomap[each_node])['order4hrg'] \ + = each_subhg.node_attr(each_node)['order4hrg'] + if 'ext_id' in each_subhg.node_attr(each_node): + 
subhg.node_attr(isomap[each_node])['ext_id'] \ + = each_subhg.node_attr(each_node)['ext_id'] + return each_idx, False + except StopIteration: + match = False + if not match: + node_dict = {} + for each_node in subhg.nodes: + node_dict[each_node] = subhg.node_attr(each_node)['symbol'].__hash__() + node_list = [] + for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]): + node_list.append(each_key) + for each_idx, each_node in enumerate(node_list): + subhg.node_attr(each_node)['order4hrg'] = each_idx + + #for each_idx, each_node in enumerate(subhg.nodes): + # subhg.node_attr(each_node)['order4hrg'] = each_idx + self.subhg_list.append(subhg) + return len(self.subhg_list) - 1, True diff --git a/graph_grammar/graph_grammar/hrg.py b/graph_grammar/graph_grammar/hrg.py new file mode 100644 index 0000000000000000000000000000000000000000..49adf224b06b6b0fbac9040865f46b0c4f20a85a --- /dev/null +++ b/graph_grammar/graph_grammar/hrg.py @@ -0,0 +1,1065 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 
2017" +__version__ = "0.1" +__date__ = "Dec 11 2017" + +from .corpus import CliqueTreeCorpus +from .base import GraphGrammarBase +from .symbols import TSymbol, NTSymbol, BondSymbol +from .utils import _node_match, _node_match_prod_rule, _edge_match, masked_softmax, common_node_list +from ..hypergraph import Hypergraph +from collections import Counter +from copy import deepcopy +from ..algo.tree_decomposition import ( + tree_decomposition, + tree_decomposition_with_hrg, + tree_decomposition_from_leaf, + topological_tree_decomposition, + molecular_tree_decomposition) +from functools import partial +from networkx.algorithms.isomorphism import GraphMatcher +from typing import List, Dict, Tuple +import networkx as nx +import numpy as np +import torch +import os +import random + +DEBUG = False + + +class ProductionRule(object): + """ A class of a production rule + + Attributes + ---------- + lhs : Hypergraph or None + the left hand side of the production rule. + if None, the rule is a starting rule. + rhs : Hypergraph + the right hand side of the production rule. 
+ """ + def __init__(self, lhs, rhs): + self.lhs = lhs + self.rhs = rhs + + @property + def is_start_rule(self) -> bool: + return self.lhs.num_nodes == 0 + + @property + def ext_node(self) -> Dict[int, str]: + """ return a dict of external nodes + """ + if self.is_start_rule: + return {} + else: + ext_node_dict = {} + for each_node in self.lhs.nodes: + ext_node_dict[self.lhs.node_attr(each_node)["ext_id"]] = each_node + return ext_node_dict + + @property + def lhs_nt_symbol(self) -> NTSymbol: + if self.is_start_rule: + return NTSymbol(degree=0, is_aromatic=False, bond_symbol_list=[]) + else: + return self.lhs.edge_attr(list(self.lhs.edges)[0])['symbol'] + + def rhs_adj_mat(self, node_edge_list): + ''' return the adjacency matrix of rhs of the production rule + ''' + return nx.adjacency_matrix(self.rhs.hg, node_edge_list) + + def draw(self, file_path=None): + return self.rhs.draw(file_path) + + def is_same(self, prod_rule, ignore_order=False): + """ judge whether this production rule is + the same as the input one, `prod_rule` + + Parameters + ---------- + prod_rule : ProductionRule + production rule to be compared + + Returns + ------- + is_same : bool + isomap : dict + isomorphism of nodes and hyperedges. + ex) {'bond_42': 'bond_37', 'bond_2': 'bond_1', + 'e36': 'e11', 'e16': 'e12', 'e25': 'e18', + 'bond_40': 'bond_38', 'e26': 'e21', 'bond_41': 'bond_39'}. + key comes from `prod_rule`, value comes from `self`. 
+ """ + if self.is_start_rule: + if not prod_rule.is_start_rule: + return False, {} + else: + if prod_rule.is_start_rule: + return False, {} + else: + if prod_rule.lhs.num_nodes != self.lhs.num_nodes: + return False, {} + + if prod_rule.rhs.num_nodes != self.rhs.num_nodes: + return False, {} + if prod_rule.rhs.num_edges != self.rhs.num_edges: + return False, {} + + subhg_bond_symbol_counter \ + = Counter([prod_rule.rhs.node_attr(each_node)['symbol'] \ + for each_node in prod_rule.rhs.nodes]) + each_bond_symbol_counter \ + = Counter([self.rhs.node_attr(each_node)['symbol'] \ + for each_node in self.rhs.nodes]) + if subhg_bond_symbol_counter != each_bond_symbol_counter: + return False, {} + + subhg_atom_symbol_counter \ + = Counter([prod_rule.rhs.edge_attr(each_edge)['symbol'] \ + for each_edge in prod_rule.rhs.edges]) + each_atom_symbol_counter \ + = Counter([self.rhs.edge_attr(each_edge)['symbol'] \ + for each_edge in self.rhs.edges]) + if subhg_atom_symbol_counter != each_atom_symbol_counter: + return False, {} + + gm = GraphMatcher(prod_rule.rhs.hg, + self.rhs.hg, + partial(_node_match_prod_rule, + ignore_order=ignore_order), + partial(_edge_match, + ignore_order=ignore_order)) + try: + return True, next(gm.isomorphisms_iter()) + except StopIteration: + return False, {} + + def applied_to(self, + hg: Hypergraph, + edge: str) -> Tuple[Hypergraph, List[str]]: + """ augment `hg` by replacing `edge` with `self.rhs`. + + Parameters + ---------- + hg : Hypergraph + edge : str + `edge` must belong to `hg` + + Returns + ------- + hg : Hypergraph + resultant hypergraph + nt_edge_list : list + list of non-terminal edges + """ + nt_edge_dict = {} + if self.is_start_rule: + if (edge is not None) or (hg is not None): + ValueError("edge and hg must be None for this prod rule.") + hg = Hypergraph() + node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented. 
+ for num_idx, each_node in enumerate(self.rhs.nodes): + hg.add_node(f"bond_{num_idx}", + #attr_dict=deepcopy(self.rhs.node_attr(each_node))) + attr_dict=self.rhs.node_attr(each_node)) + node_map_rhs[each_node] = f"bond_{num_idx}" + for each_edge in self.rhs.edges: + node_list = [] + for each_node in self.rhs.nodes_in_edge(each_edge): + node_list.append(node_map_rhs[each_node]) + if isinstance(self.rhs.nodes_in_edge(each_edge), set): + node_list = set(node_list) + edge_id = hg.add_edge( + node_list, + #attr_dict=deepcopy(self.rhs.edge_attr(each_edge))) + attr_dict=self.rhs.edge_attr(each_edge)) + if "nt_idx" in hg.edge_attr(edge_id): + nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id + nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))] + return hg, nt_edge_list + else: + if edge not in hg.edges: + raise ValueError("the input hyperedge does not exist.") + if hg.edge_attr(edge)["terminal"]: + raise ValueError("the input hyperedge is terminal.") + if hg.edge_attr(edge)['symbol'] != self.lhs_nt_symbol: + print(hg.edge_attr(edge)['symbol'], self.lhs_nt_symbol) + raise ValueError("the input hyperedge and lhs have inconsistent number of nodes.") + if DEBUG: + for node_idx, each_node in enumerate(hg.nodes_in_edge(edge)): + other_node = self.lhs.nodes_in_edge(list(self.lhs.edges)[0])[node_idx] + attr = deepcopy(self.lhs.node_attr(other_node)) + attr.pop('ext_id') + if hg.node_attr(each_node) != attr: + raise ValueError('node attributes are inconsistent.') + + # order of nodes that belong to the non-terminal edge in hg + nt_order_dict = {} # hg_node -> order ("bond_17" : 1) + nt_order_dict_inv = {} # order -> hg_node + for each_idx, each_node in enumerate(hg.nodes_in_edge(edge)): + nt_order_dict[each_node] = each_idx + nt_order_dict_inv[each_idx] = each_node + + # construct a node_map_rhs: rhs -> new hg + node_map_rhs = {} # node id in rhs -> node id in hg, where rhs is augmented. 
+ node_idx = hg.num_nodes + for each_node in self.rhs.nodes: + if "ext_id" in self.rhs.node_attr(each_node): + node_map_rhs[each_node] \ + = nt_order_dict_inv[ + self.rhs.node_attr(each_node)["ext_id"]] + else: + node_map_rhs[each_node] = f"bond_{node_idx}" + node_idx += 1 + + # delete non-terminal + hg.remove_edge(edge) + + # add nodes to hg + for each_node in self.rhs.nodes: + hg.add_node(node_map_rhs[each_node], + attr_dict=self.rhs.node_attr(each_node)) + + # add hyperedges to hg + for each_edge in self.rhs.edges: + node_list_hg = [] + for each_node in self.rhs.nodes_in_edge(each_edge): + node_list_hg.append(node_map_rhs[each_node]) + edge_id = hg.add_edge( + node_list_hg, + attr_dict=self.rhs.edge_attr(each_edge))#deepcopy(self.rhs.edge_attr(each_edge))) + if "nt_idx" in hg.edge_attr(edge_id): + nt_edge_dict[hg.edge_attr(edge_id)["nt_idx"]] = edge_id + nt_edge_list = [nt_edge_dict[key] for key in range(len(nt_edge_dict))] + return hg, nt_edge_list + + def revert(self, hg: Hypergraph, return_subhg=False): + ''' revert applying this production rule. + i.e., if there exists a subhypergraph that matches the r.h.s. of this production rule, + this method replaces the subhypergraph with a non-terminal hyperedge. + + Parameters + ---------- + hg : Hypergraph + hypergraph to be reverted + return_subhg : bool + if True, the removed subhypergraph will be returned. + + Returns + ------- + hg : Hypergraph + the resultant hypergraph. if it cannot be reverted, the original one is returned without any replacement. + success : bool + this indicates whether reverting is successed or not. + ''' + gm = GraphMatcher(hg.hg, self.rhs.hg, node_match=_node_match_prod_rule, + edge_match=_edge_match) + try: + # in case when the matched subhg is connected to the other part via external nodes and more. 
+ not_iso = True + while not_iso: + isomap = next(gm.subgraph_isomorphisms_iter()) + adj_node_set = set([]) # reachable nodes from the internal nodes + subhg_node_set = set(isomap.keys()) # nodes in subhg + for each_node in subhg_node_set: + adj_node_set.add(each_node) + if isomap[each_node] not in self.ext_node.values(): + adj_node_set.update(hg.hg.adj[each_node]) + if adj_node_set == subhg_node_set: + not_iso = False + else: + if return_subhg: + return hg, False, Hypergraph() + else: + return hg, False + inv_isomap = {v: k for k, v in isomap.items()} + ''' + isomap = {'e35': 'e8', 'bond_13': 'bond_18', 'bond_14': 'bond_19', + 'bond_15': 'bond_17', 'e29': 'e23', 'bond_12': 'bond_20'} + where keys come from `hg` and values come from `self.rhs` + ''' + except StopIteration: + if return_subhg: + return hg, False, Hypergraph() + else: + return hg, False + + if return_subhg: + subhg = Hypergraph() + for each_node in hg.nodes: + if each_node in isomap: + subhg.add_node(each_node, attr_dict=hg.node_attr(each_node)) + for each_edge in hg.edges: + if each_edge in isomap: + subhg.add_edge(hg.nodes_in_edge(each_edge), + attr_dict=hg.edge_attr(each_edge), + edge_name=each_edge) + subhg.edge_idx = hg.edge_idx + + # remove subhg except for the externael nodes + for each_key, each_val in isomap.items(): + if each_key.startswith('e'): + hg.remove_edge(each_key) + for each_key, each_val in isomap.items(): + if each_key.startswith('bond_'): + if each_val not in self.ext_node.values(): + hg.remove_node(each_key) + + # add non-terminal hyperedge + nt_node_list = [] + for each_ext_id in self.ext_node.keys(): + nt_node_list.append(inv_isomap[self.ext_node[each_ext_id]]) + + hg.add_edge(nt_node_list, + attr_dict=dict( + terminal=False, + symbol=self.lhs_nt_symbol)) + if return_subhg: + return hg, True, subhg + else: + return hg, True + + +class ProductionRuleCorpus(object): + + ''' + A corpus of production rules. 
+ This class maintains + (i) list of unique production rules, + (ii) list of unique edge symbols (both terminal and non-terminal), and + (iii) list of unique node symbols. + + Attributes + ---------- + prod_rule_list : list + list of unique production rules + edge_symbol_list : list + list of unique symbols (including both terminal and non-terminal) + node_symbol_list : list + list of node symbols + nt_symbol_list : list + list of unique lhs symbols + ext_id_list : list + list of ext_ids + lhs_in_prod_rule : array + a matrix of lhs vs prod_rule (= lhs_in_prod_rule) + ''' + + def __init__(self): + self.prod_rule_list = [] + self.edge_symbol_list = [] + self.edge_symbol_dict = {} + self.node_symbol_list = [] + self.node_symbol_dict = {} + self.nt_symbol_list = [] + self.ext_id_list = [] + self._lhs_in_prod_rule = None + self.lhs_in_prod_rule_row_list = [] + self.lhs_in_prod_rule_col_list = [] + + @property + def lhs_in_prod_rule(self): + if self._lhs_in_prod_rule is None: + self._lhs_in_prod_rule = torch.sparse.FloatTensor( + torch.LongTensor(list(zip(self.lhs_in_prod_rule_row_list, self.lhs_in_prod_rule_col_list))).t(), + torch.FloatTensor([1.0]*len(self.lhs_in_prod_rule_col_list)), + torch.Size([len(self.nt_symbol_list), len(self.prod_rule_list)]) + ).to_dense() + return self._lhs_in_prod_rule + + @property + def num_prod_rule(self): + ''' return the number of production rules + + Returns + ------- + int : the number of unique production rules + ''' + return len(self.prod_rule_list) + + @property + def start_rule_list(self): + ''' return a list of start rules + + Returns + ------- + list : list of start rules + ''' + start_rule_list = [] + for each_prod_rule in self.prod_rule_list: + if each_prod_rule.is_start_rule: + start_rule_list.append(each_prod_rule) + return start_rule_list + + @property + def num_edge_symbol(self): + return len(self.edge_symbol_list) + + @property + def num_node_symbol(self): + return len(self.node_symbol_list) + + @property + def 
num_ext_id(self): + return len(self.ext_id_list) + + def construct_feature_vectors(self): + ''' this method constructs feature vectors for the production rules collected so far. + currently, NTSymbol and TSymbol are treated in the same manner. + ''' + feature_id_dict = {} + feature_id_dict['TSymbol'] = 0 + feature_id_dict['NTSymbol'] = 1 + feature_id_dict['BondSymbol'] = 2 + for each_edge_symbol in self.edge_symbol_list: + for each_attr in each_edge_symbol.__dict__.keys(): + each_val = each_edge_symbol.__dict__[each_attr] + if isinstance(each_val, list): + each_val = tuple(each_val) + if (each_attr, each_val) not in feature_id_dict: + feature_id_dict[(each_attr, each_val)] = len(feature_id_dict) + + for each_node_symbol in self.node_symbol_list: + for each_attr in each_node_symbol.__dict__.keys(): + each_val = each_node_symbol.__dict__[each_attr] + if isinstance(each_val, list): + each_val = tuple(each_val) + if (each_attr, each_val) not in feature_id_dict: + feature_id_dict[(each_attr, each_val)] = len(feature_id_dict) + for each_ext_id in self.ext_id_list: + feature_id_dict[('ext_id', each_ext_id)] = len(feature_id_dict) + dim = len(feature_id_dict) + + feature_dict = {} + for each_edge_symbol in self.edge_symbol_list: + idx_list = [] + idx_list.append(feature_id_dict[each_edge_symbol.__class__.__name__]) + for each_attr in each_edge_symbol.__dict__.keys(): + each_val = each_edge_symbol.__dict__[each_attr] + if isinstance(each_val, list): + each_val = tuple(each_val) + idx_list.append(feature_id_dict[(each_attr, each_val)]) + feature = torch.sparse.LongTensor( + torch.LongTensor([idx_list]), + torch.ones(len(idx_list)), + torch.Size([len(feature_id_dict)]) + ) + feature_dict[each_edge_symbol] = feature + + for each_node_symbol in self.node_symbol_list: + idx_list = [] + idx_list.append(feature_id_dict[each_node_symbol.__class__.__name__]) + for each_attr in each_node_symbol.__dict__.keys(): + each_val = each_node_symbol.__dict__[each_attr] + if 
isinstance(each_val, list): + each_val = tuple(each_val) + idx_list.append(feature_id_dict[(each_attr, each_val)]) + feature = torch.sparse.LongTensor( + torch.LongTensor([idx_list]), + torch.ones(len(idx_list)), + torch.Size([len(feature_id_dict)]) + ) + feature_dict[each_node_symbol] = feature + for each_ext_id in self.ext_id_list: + idx_list = [feature_id_dict[('ext_id', each_ext_id)]] + feature_dict[('ext_id', each_ext_id)] \ + = torch.sparse.LongTensor( + torch.LongTensor([idx_list]), + torch.ones(len(idx_list)), + torch.Size([len(feature_id_dict)]) + ) + return feature_dict, dim + + def edge_symbol_idx(self, symbol): + return self.edge_symbol_dict[symbol] + + def node_symbol_idx(self, symbol): + return self.node_symbol_dict[symbol] + + def append(self, prod_rule: ProductionRule) -> Tuple[int, ProductionRule]: + """ return whether the input production rule is new or not, and its production rule id. + Production rules are regarded as the same if + i) there exists a one-to-one mapping of nodes and edges, and + ii) all the attributes associated with nodes and hyperedges are the same. + + Parameters + ---------- + prod_rule : ProductionRule + + Returns + ------- + prod_rule_id : int + production rule index. if new, a new index will be assigned. + prod_rule : ProductionRule + """ + num_lhs = len(self.nt_symbol_list) + for each_idx, each_prod_rule in enumerate(self.prod_rule_list): + is_same, isomap = prod_rule.is_same(each_prod_rule) + if is_same: + # we do not care about edge and node names, but care about the order of non-terminal edges. 
+ for key, val in isomap.items(): # key : edges & nodes in each_prod_rule.rhs , val : those in prod_rule.rhs + if key.startswith("bond_"): + continue + + # rewrite `nt_idx` in `prod_rule` for further processing + if "nt_idx" in prod_rule.rhs.edge_attr(val).keys(): + if "nt_idx" not in each_prod_rule.rhs.edge_attr(key).keys(): + raise ValueError + prod_rule.rhs.set_edge_attr( + val, + {'nt_idx': each_prod_rule.rhs.edge_attr(key)["nt_idx"]}) + return each_idx, prod_rule + self.prod_rule_list.append(prod_rule) + self._update_edge_symbol_list(prod_rule) + self._update_node_symbol_list(prod_rule) + self._update_ext_id_list(prod_rule) + + lhs_idx = self.nt_symbol_list.index(prod_rule.lhs_nt_symbol) + self.lhs_in_prod_rule_row_list.append(lhs_idx) + self.lhs_in_prod_rule_col_list.append(len(self.prod_rule_list)-1) + self._lhs_in_prod_rule = None + return len(self.prod_rule_list)-1, prod_rule + + def get_prod_rule(self, prod_rule_idx: int) -> ProductionRule: + return self.prod_rule_list[prod_rule_idx] + + def sample(self, unmasked_logit_array, nt_symbol, deterministic=False): + ''' sample a production rule whose lhs is `nt_symbol`, followihng `unmasked_logit_array`. 
+ + Parameters + ---------- + unmasked_logit_array : array-like, length `num_prod_rule` + nt_symbol : NTSymbol + ''' + if not isinstance(unmasked_logit_array, np.ndarray): + unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64) + if deterministic: + prob = masked_softmax(unmasked_logit_array, + self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)) + return self.prod_rule_list[np.argmax(prob)] + else: + return np.random.choice( + self.prod_rule_list, 1, + p=masked_softmax(unmasked_logit_array, + self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)))[0] + + def masked_logprob(self, unmasked_logit_array, nt_symbol): + if not isinstance(unmasked_logit_array, np.ndarray): + unmasked_logit_array = unmasked_logit_array.numpy().astype(np.float64) + prob = masked_softmax(unmasked_logit_array, + self.lhs_in_prod_rule[self.nt_symbol_list.index(nt_symbol)].numpy().astype(np.float64)) + return np.log(prob) + + def _update_edge_symbol_list(self, prod_rule: ProductionRule): + ''' update edge symbol list + + Parameters + ---------- + prod_rule : ProductionRule + ''' + if prod_rule.lhs_nt_symbol not in self.nt_symbol_list: + self.nt_symbol_list.append(prod_rule.lhs_nt_symbol) + + for each_edge in prod_rule.rhs.edges: + if prod_rule.rhs.edge_attr(each_edge)['symbol'] not in self.edge_symbol_dict: + edge_symbol_idx = len(self.edge_symbol_list) + self.edge_symbol_list.append(prod_rule.rhs.edge_attr(each_edge)['symbol']) + self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] = edge_symbol_idx + else: + edge_symbol_idx = self.edge_symbol_dict[prod_rule.rhs.edge_attr(each_edge)['symbol']] + prod_rule.rhs.edge_attr(each_edge)['symbol_idx'] = edge_symbol_idx + pass + + def _update_node_symbol_list(self, prod_rule: ProductionRule): + ''' update node symbol list + + Parameters + ---------- + prod_rule : ProductionRule + ''' + for each_node in prod_rule.rhs.nodes: + if 
class HyperedgeReplacementGrammar(GraphGrammarBase):
    """
    Learn a hyperedge replacement grammar from a set of hypergraphs.

    Attributes
    ----------
    prod_rule_corpus : ProductionRuleCorpus
        corpus of the unique production rules learned so far.
    clique_tree_corpus : CliqueTreeCorpus
        corpus of the subhypergraphs appearing in the clique trees.
    ignore_order : bool
        if True, the order of the nodes in a hyperedge is ignored.
    tree_decomposition : callable
        maps a hypergraph to its clique tree.
    """
    def __init__(self,
                 tree_decomposition=molecular_tree_decomposition,
                 ignore_order=False, **kwargs):
        from functools import partial
        self.prod_rule_corpus = ProductionRuleCorpus()
        self.clique_tree_corpus = CliqueTreeCorpus()
        self.ignore_order = ignore_order
        # extra keyword arguments are forwarded to the tree decomposition routine
        self.tree_decomposition = partial(tree_decomposition, **kwargs)

    @property
    def num_prod_rule(self):
        ''' the number of unique production rules

        Returns
        -------
        int : the number of unique production rules
        '''
        return self.prod_rule_corpus.num_prod_rule

    @property
    def start_rule_list(self):
        ''' a list of start rules

        Returns
        -------
        list : list of start rules
        '''
        return self.prod_rule_corpus.start_rule_list

    @property
    def prod_rule_list(self):
        ''' all production rules learned so far

        Returns
        -------
        list of ProductionRule
        '''
        return self.prod_rule_corpus.prod_rule_list

    def learn(self, hg_list, logger=print, max_mol=np.inf, print_freq=500):
        """ learn from a list of hypergraphs

        Parameters
        ----------
        hg_list : list of Hypergraph
        logger : callable
            sink for progress messages (default: `print`).
        max_mol : int or float
            maximum number of molecules to process.
        print_freq : int
            emit a progress message every `print_freq` molecules.

        Returns
        -------
        prod_rule_seq_list : list of integers
            each element corresponds to a sequence of production rules
            to generate each hypergraph.
        """
        prod_rule_seq_list = []
        for each_idx, each_hg in enumerate(hg_list):
            clique_tree = self.tree_decomposition(each_hg)

            # root the clique tree and register its subhypergraphs in the corpus
            root_node = _find_root(clique_tree)
            clique_tree = self.clique_tree_corpus.add_to_subhg_list(clique_tree, root_node)
            prod_rule_seq = []

            children = sorted(list(clique_tree[root_node].keys()))

            # extract a temporary production rule for the root (no parent)
            prod_rule = extract_prod_rule(
                None,
                clique_tree.nodes[root_node]["subhg"],
                [clique_tree.nodes[each_child]["subhg"]
                 for each_child in children],
                clique_tree.nodes[root_node].get('subhg_idx', None))

            # deduplicate against the corpus; the canonical rule is returned
            prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
            children = reorder_children(root_node,
                                        children,
                                        prod_rule,
                                        clique_tree)
            stack = [(root_node, each_child) for each_child in children[::-1]]
            prod_rule_seq.append(prod_rule_id)

            # depth-first traversal over the rest of the clique tree
            while len(stack) != 0:
                # get a triple of parent, myself, and children
                parent, myself = stack.pop()
                children = sorted(list(dict(clique_tree[myself]).keys()))
                children.remove(parent)

                # extract a temporary production rule
                prod_rule = extract_prod_rule(
                    clique_tree.nodes[parent]["subhg"],
                    clique_tree.nodes[myself]["subhg"],
                    [clique_tree.nodes[each_child]["subhg"]
                     for each_child in children],
                    clique_tree.nodes[myself].get('subhg_idx', None))

                # update the production rule list
                prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule)
                children = reorder_children(myself,
                                            children,
                                            prod_rule,
                                            clique_tree)
                stack.extend([(myself, each_child)
                              for each_child in children[::-1]])
                prod_rule_seq.append(prod_rule_id)
            prod_rule_seq_list.append(prod_rule_seq)
            if (each_idx + 1) % print_freq == 0:
                msg = f'#(molecules processed)={each_idx+1}\t'\
                    f'#(production rules)={self.prod_rule_corpus.num_prod_rule}\t'\
                    f'#(subhg in corpus)={self.clique_tree_corpus.size}'
                logger(msg)
            # BUGFIX: the original `each_idx > max_mol` processed two molecules
            # beyond the requested limit; stop at exactly `max_mol`.
            if each_idx + 1 >= max_mol:
                break

        # CONSISTENCY: report through `logger` like the progress messages above
        # (the original bypassed it with a bare `print`).
        logger(f'corpus_size = {self.clique_tree_corpus.size}')
        return prod_rule_seq_list

    def sample(self, z, deterministic=False):
        """ sample a new hypergraph from HRG.

        Parameters
        ----------
        z : array-like, shape (len, num_prod_rule)
            logit
        deterministic : bool
            if True, deterministic sampling

        Returns
        -------
        Hypergraph

        Raises
        ------
        RuntimeError
            if the logit sequence runs out before all non-terminals are expanded.
        """
        seq_idx = 0
        # the last column is dropped before sampling
        z = z[:, :-1]
        # the start rule expands a degree-0 non-terminal
        init_prod_rule = self.prod_rule_corpus.sample(z[0], NTSymbol(degree=0,
                                                                     is_aromatic=False,
                                                                     bond_symbol_list=[]),
                                                      deterministic=deterministic)
        hg, nt_edge_list = init_prod_rule.applied_to(None, None)
        stack = deepcopy(nt_edge_list[::-1])
        while len(stack) != 0 and seq_idx < z.shape[0] - 1:
            seq_idx += 1
            nt_edge = stack.pop()
            nt_symbol = hg.edge_attr(nt_edge)['symbol']
            prod_rule = self.prod_rule_corpus.sample(z[seq_idx], nt_symbol, deterministic=deterministic)
            hg, nt_edge_list = prod_rule.applied_to(hg, nt_edge)
            stack.extend(nt_edge_list[::-1])
        if len(stack) != 0:
            raise RuntimeError(f'{len(stack)} non-terminals are left.')
        return hg

    def construct(self, prod_rule_seq):
        """ construct a hypergraph following `prod_rule_seq`

        Parameters
        ----------
        prod_rule_seq : list of integers
            a sequence of production rules.

        Returns
        -------
        UndirectedHypergraph
        """
        seq_idx = 0
        init_prod_rule = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx])
        hg, nt_edge_list = init_prod_rule.applied_to(None, None)
        stack = deepcopy(nt_edge_list[::-1])
        while len(stack) != 0:
            seq_idx += 1
            nt_edge = stack.pop()
            hg, nt_edge_list = self.prod_rule_corpus.get_prod_rule(prod_rule_seq[seq_idx]).applied_to(hg, nt_edge)
            stack.extend(nt_edge_list[::-1])
        return hg

    def update_prod_rule_list(self, prod_rule):
        """ register `prod_rule` in the corpus, deduplicating against known rules.

        Production rules are regarded as the same if
        i) there exists a one-to-one mapping of nodes and edges, and
        ii) all the attributes associated with nodes and hyperedges are the same.

        Parameters
        ----------
        prod_rule : ProductionRule

        Returns
        -------
        prod_rule_id : int
            production rule index. if new, a new index is assigned.
        prod_rule : ProductionRule
            the canonical (stored) production rule.
        """
        return self.prod_rule_corpus.append(prod_rule)
+ """ + prod_rule_seq_list = [] + for each_hg in hg_list: + clique_tree, root_node = tree_decomposition_with_hrg(each_hg, self, return_root=True) + + prod_rule_seq = [] + stack = [] + + # get a pair of myself and children + children = sorted(list(clique_tree[root_node].keys())) + + # extract a temporary production rule + prod_rule = extract_prod_rule(None, clique_tree.nodes[root_node]["subhg"], + [clique_tree.nodes[each_child]["subhg"] for each_child in children]) + + # update the production rule list + prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule) + children = reorder_children(root_node, children, prod_rule, clique_tree) + stack.extend([(root_node, each_child) for each_child in children[::-1]]) + prod_rule_seq.append(prod_rule_id) + + while len(stack) != 0: + # get a triple of parent, myself, and children + parent, myself = stack.pop() + children = sorted(list(dict(clique_tree[myself]).keys())) + children.remove(parent) + + # extract a temp prod rule + prod_rule = extract_prod_rule( + clique_tree.nodes[parent]["subhg"], clique_tree.nodes[myself]["subhg"], + [clique_tree.nodes[each_child]["subhg"] for each_child in children]) + + # update the prod rule list + prod_rule_id, prod_rule = self.update_prod_rule_list(prod_rule) + children = reorder_children(myself, children, prod_rule, clique_tree) + stack.extend([(myself, each_child) for each_child in children[::-1]]) + prod_rule_seq.append(prod_rule_id) + prod_rule_seq_list.append(prod_rule_seq) + self._compute_stats() + return prod_rule_seq_list + + +def reorder_children(myself, children, prod_rule, clique_tree): + """ reorder children so that they match the order in `prod_rule`. 
+ + Parameters + ---------- + myself : int + children : list of int + prod_rule : ProductionRule + clique_tree : nx.Graph + + Returns + ------- + new_children : list of str + reordered children + """ + perm = {} # key : `nt_idx`, val : child + for each_edge in prod_rule.rhs.edges: + if "nt_idx" in prod_rule.rhs.edge_attr(each_edge).keys(): + for each_child in children: + common_node_set = set( + common_node_list(clique_tree.nodes[myself]["subhg"], + clique_tree.nodes[each_child]["subhg"])[0]) + if set(prod_rule.rhs.nodes_in_edge(each_edge)) == common_node_set: + assert prod_rule.rhs.edge_attr(each_edge)["nt_idx"] not in perm + perm[prod_rule.rhs.edge_attr(each_edge)["nt_idx"]] = each_child + new_children = [] + assert len(perm) == len(children) + for i in range(len(perm)): + new_children.append(perm[i]) + return new_children + + +def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None): + """ extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`. 
def extract_prod_rule(parent_hg, myself_hg, children_hg_list, subhg_idx=None):
    """ extract a production rule from a triple of `parent_hg`, `myself_hg`, and `children_hg_list`.

    Parameters
    ----------
    parent_hg : Hypergraph or None
        subhypergraph of the parent clique-tree node; None for the root.
    myself_hg : Hypergraph
    children_hg_list : list of Hypergraph
    subhg_idx : int or None
        index of `myself_hg` in the subhypergraph corpus, if known.

    Returns
    -------
    ProductionRule, consisting of
        lhs : Hypergraph or None
        rhs : Hypergraph

    Raises
    ------
    ValueError
        if `ext_id` is set on only a part of the boundary nodes.
    """
    def _add_ext_node(hg, ext_nodes):
        """ mark nodes to be external (ordered ids are assigned)

        Parameters
        ----------
        hg : UndirectedHypergraph
        ext_nodes : list of str
            list of external nodes

        Returns
        -------
        hg : Hypergraph
            nodes in `ext_nodes` are marked to be external
        """
        ext_id = 0
        ext_id_exists = []
        for each_node in ext_nodes:
            ext_id_exists.append('ext_id' in hg.node_attr(each_node))
        if ext_id_exists and any(ext_id_exists) != all(ext_id_exists):
            # either all or none of the boundary nodes may carry an ext_id
            raise ValueError('ext_id is set on only a part of the external nodes')
        if not all(ext_id_exists):
            for each_node in ext_nodes:
                hg.node_attr(each_node)['ext_id'] = ext_id
                ext_id += 1
        return hg

    def _check_aromatic(hg, node_list):
        """ return (any node aromatic?, per-node aromaticity flags). """
        is_aromatic = False
        node_aromatic_list = []
        for each_node in node_list:
            if hg.node_attr(each_node)['symbol'].is_aromatic:
                is_aromatic = True
                node_aromatic_list.append(True)
            else:
                node_aromatic_list.append(False)
        return is_aromatic, node_aromatic_list

    def _check_ring(hg):
        """ True iff every edge of `hg` is either temporary or non-terminal. """
        for each_edge in hg.edges:
            if not ('tmp' in hg.edge_attr(each_edge) or (not hg.edge_attr(each_edge)['terminal'])):
                return False
        return True

    if parent_hg is None:
        # root of the clique tree: the left-hand side is empty (start rule)
        lhs = Hypergraph()
        node_list = []
    else:
        # the lhs consists of the boundary nodes shared with the parent,
        # joined by a single non-terminal hyperedge
        lhs = Hypergraph()
        node_list, edge_exists = common_node_list(parent_hg, myself_hg)
        for each_node in node_list:
            lhs.add_node(each_node,
                         deepcopy(myself_hg.node_attr(each_node)))
        is_aromatic, _ = _check_aromatic(parent_hg, node_list)
        for_ring = _check_ring(myself_hg)
        bond_symbol_list = []
        for each_node in node_list:
            bond_symbol_list.append(parent_hg.node_attr(each_node)['symbol'])
        lhs.add_edge(
            node_list,
            attr_dict=dict(
                terminal=False,
                edge_exists=edge_exists,
                symbol=NTSymbol(
                    degree=len(node_list),
                    is_aromatic=is_aromatic,
                    bond_symbol_list=bond_symbol_list,
                    for_ring=for_ring)))
    # BUGFIX: the original wrapped the two calls below in
    # `try: ... except ValueError: import pdb; pdb.set_trace()` — a debugger
    # trap left in production code. Let the (now descriptive) ValueError propagate.
    lhs = _add_ext_node(lhs, node_list)

    rhs = remove_tmp_edge(deepcopy(myself_hg))
    rhs = _add_ext_node(rhs, node_list)

    # add one non-terminal hyperedge per child, labeled with its order
    nt_idx = 0
    if children_hg_list is not None:
        for each_child_hg in children_hg_list:
            node_list, edge_exists = common_node_list(myself_hg, each_child_hg)
            is_aromatic, _ = _check_aromatic(myself_hg, node_list)
            for_ring = _check_ring(each_child_hg)
            bond_symbol_list = []
            for each_node in node_list:
                bond_symbol_list.append(myself_hg.node_attr(each_node)['symbol'])
            rhs.add_edge(
                node_list,
                attr_dict=dict(
                    terminal=False,
                    nt_idx=nt_idx,
                    edge_exists=edge_exists,
                    symbol=NTSymbol(degree=len(node_list),
                                    is_aromatic=is_aromatic,
                                    bond_symbol_list=bond_symbol_list,
                                    for_ring=for_ring)))
            nt_idx += 1
    prod_rule = ProductionRule(lhs, rhs)
    prod_rule.subhg_idx = subhg_idx
    if DEBUG:
        if sorted(list(prod_rule.ext_node.keys())) \
           != list(np.arange(len(prod_rule.ext_node))):
            raise RuntimeError('ext_id is not continuous')
    return prod_rule
def remove_tmp_edge(hg):
    """ drop every hyperedge of `hg` whose attributes carry a truthy 'tmp'
    flag (in place), and return `hg`. """
    tmp_edges = [each_edge for each_edge in hg.edges
                 if hg.edge_attr(each_edge).get('tmp', False)]
    hg.remove_edges(tmp_edges)
    return hg
2018" +__version__ = "0.1" +__date__ = "Jan 1 2018" + +from typing import List + +class TSymbol(object): + + ''' terminal symbol + + Attributes + ---------- + degree : int + the number of nodes in a hyperedge + is_aromatic : bool + whether or not the hyperedge is in an aromatic ring + symbol : str + atomic symbol + num_explicit_Hs : int + the number of hydrogens associated to this hyperedge + formal_charge : int + charge + chirality : int + chirality + ''' + + def __init__(self, degree, is_aromatic, + symbol, num_explicit_Hs, formal_charge, chirality): + self.degree = degree + self.is_aromatic = is_aromatic + self.symbol = symbol + self.num_explicit_Hs = num_explicit_Hs + self.formal_charge = formal_charge + self.chirality = chirality + + @property + def terminal(self): + return True + + def __eq__(self, other): + if not isinstance(other, TSymbol): + return False + if self.degree != other.degree: + return False + if self.is_aromatic != other.is_aromatic: + return False + if self.symbol != other.symbol: + return False + if self.num_explicit_Hs != other.num_explicit_Hs: + return False + if self.formal_charge != other.formal_charge: + return False + if self.chirality != other.chirality: + return False + return True + + def __hash__(self): + return self.__str__().__hash__() + + def __str__(self): + return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\ + f'symbol={self.symbol}, '\ + f'num_explicit_Hs={self.num_explicit_Hs}, '\ + f'formal_charge={self.formal_charge}, chirality={self.chirality}' + + +class NTSymbol(object): + + ''' non-terminal symbol + + Attributes + ---------- + degree : int + degree of the hyperedge + is_aromatic : bool + if True, at least one of the associated bonds must be aromatic. + node_aromatic_list : list of bool + indicate whether each of the nodes is aromatic or not. 
class NTSymbol(object):

    ''' non-terminal symbol

    Attributes
    ----------
    degree : int
        degree of the hyperedge
    is_aromatic : bool
        if True, at least one of the associated bonds must be aromatic.
    bond_symbol_list : list of BondSymbol
        bond symbol of each node in the hyperedge
    for_ring : bool
        if True, this non-terminal stands for a ring substructure.
    '''

    def __init__(self, degree: int, is_aromatic: bool,
                 bond_symbol_list: list,
                 for_ring=False):
        self.degree = degree
        self.is_aromatic = is_aromatic
        self.for_ring = for_ring
        self.bond_symbol_list = bond_symbol_list

    @property
    def terminal(self) -> bool:
        # non-terminal symbols always report False
        return False

    @property
    def symbol(self):
        return f'NT{self.degree}'

    def __eq__(self, other) -> bool:
        if not isinstance(other, NTSymbol):
            return False

        if self.degree != other.degree:
            return False
        if self.is_aromatic != other.is_aromatic:
            return False
        if self.for_ring != other.for_ring:
            return False
        if len(self.bond_symbol_list) != len(other.bond_symbol_list):
            return False
        # bond symbols must agree element-wise, in order
        for each_idx in range(len(self.bond_symbol_list)):
            if self.bond_symbol_list[each_idx] != other.bond_symbol_list[each_idx]:
                return False
        return True

    def __hash__(self):
        # hash derives from the string rendering, consistent with __eq__
        return self.__str__().__hash__()

    def __str__(self) -> str:
        # BUGFIX: a ', ' separator was missing between `bond_symbol_list`
        # and `for_ring`, gluing the two fields together in the output.
        return f'degree={self.degree}, is_aromatic={self.is_aromatic}, '\
            f'bond_symbol_list={[str(each_symbol) for each_symbol in self.bond_symbol_list]}, '\
            f'for_ring={self.for_ring}'
class BondSymbol(object):


    ''' Bond symbol

    Attributes
    ----------
    is_aromatic : bool
        if True, at least one of the associated bonds must be aromatic.
    bond_type : int
        bond type of each node"
    stereo : int
        stereo configuration of the bond.
    '''

    def __init__(self, is_aromatic: bool,
                 bond_type: int,
                 stereo: int):
        self.is_aromatic = is_aromatic
        self.bond_type = bond_type
        self.stereo = stereo

    def __eq__(self, other) -> bool:
        if not isinstance(other, BondSymbol):
            return False
        # all three attributes must agree
        return (self.is_aromatic == other.is_aromatic
                and self.bond_type == other.bond_type
                and self.stereo == other.stereo)

    def __hash__(self):
        # hash derives from the string rendering, consistent with __eq__
        return hash(str(self))

    def __str__(self) -> str:
        # NOTE: the trailing ', ' is preserved byte-for-byte; hashes
        # depend on this exact rendering.
        return f'is_aromatic={self.is_aromatic}, '\
            f'bond_type={self.bond_type}, '\
            f'stereo={self.stereo}, '
2018" +__version__ = "0.1" +__date__ = "Jun 4 2018" + +from ..hypergraph import Hypergraph +from copy import deepcopy +from typing import List +import numpy as np + + +def common_node_list(hg1: Hypergraph, hg2: Hypergraph) -> List[str]: + """ return a list of common nodes + + Parameters + ---------- + hg1, hg2 : Hypergraph + + Returns + ------- + list of str + list of common nodes + """ + if hg1 is None or hg2 is None: + return [], False + else: + node_set = hg1.nodes.intersection(hg2.nodes) + node_dict = {} + if 'order4hrg' in hg1.node_attr(list(hg1.nodes)[0]): + for each_node in node_set: + node_dict[each_node] = hg1.node_attr(each_node)['order4hrg'] + else: + for each_node in node_set: + node_dict[each_node] = hg1.node_attr(each_node)['symbol'].__hash__() + node_list = [] + for each_key, _ in sorted(node_dict.items(), key=lambda x:x[1]): + node_list.append(each_key) + edge_name = hg1.has_edge(node_list, ignore_order=True) + if edge_name: + if not hg1.edge_attr(edge_name).get('terminal', True): + node_list = hg1.nodes_in_edge(edge_name) + return node_list, True + else: + return node_list, False + + +def _node_match(node1, node2): + # if the nodes are hyperedges, `atom_attr` determines the match + if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge': + return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol'] + elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node': + # bond_symbol + return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol'] + else: + return False + +def _easy_node_match(node1, node2): + # if the nodes are hyperedges, `atom_attr` determines the match + if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge': + return node1["attr_dict"].get('symbol', None) == node2["attr_dict"].get('symbol', None) + elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node': + # bond_symbol + return node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1)\ + and 
node1['attr_dict']['symbol'] == node2['attr_dict']['symbol'] + else: + return False + + +def _node_match_prod_rule(node1, node2, ignore_order=False): + # if the nodes are hyperedges, `atom_attr` determines the match + if node1['bipartite'] == 'edge' and node2['bipartite'] == 'edge': + return node1["attr_dict"]['symbol'] == node2["attr_dict"]['symbol'] + elif node1['bipartite'] == 'node' and node2['bipartite'] == 'node': + # ext_id, order4hrg, bond_symbol + if ignore_order: + return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol'] + else: + return node1['attr_dict']['symbol'] == node2['attr_dict']['symbol']\ + and node1['attr_dict'].get('ext_id', -1) == node2['attr_dict'].get('ext_id', -1) + else: + return False + + +def _edge_match(edge1, edge2, ignore_order=False): + #return True + if ignore_order: + return True + else: + return edge1["order"] == edge2["order"] + +def masked_softmax(logit, mask): + ''' compute a probability distribution from logit + + Parameters + ---------- + logit : array-like, length D + each element indicates how each dimension is likely to be chosen + (the larger, the more likely) + mask : array-like, length D + each element is either 0 or 1. + if 0, the dimension is ignored + when computing the probability distribution. + + Returns + ------- + prob_dist : array, length D + probability distribution computed from logit. + if `mask[d] = 0`, `prob_dist[d] = 0`. 
def masked_softmax(logit, mask):
    ''' compute a probability distribution from logit

    Parameters
    ----------
    logit : array-like, length D
        each element indicates how each dimension is likely to be chosen
        (the larger, the more likely)
    mask : array-like, length D
        each element is either 0 or 1.
        if 0, the dimension is ignored
        when computing the probability distribution.

    Returns
    -------
    prob_dist : array, length D
        probability distribution computed from logit.
        if `mask[d] = 0`, `prob_dist[d] = 0`.

    Raises
    ------
    ValueError
        if the two shapes differ, or if every dimension is masked out.
    '''
    if logit.shape != mask.shape:
        raise ValueError('logit and mask must have the same shape')
    # ROBUSTNESS: an all-zero mask used to produce a silent 0/0 -> NaN output;
    # fail loudly instead.
    if not np.any(mask):
        raise ValueError('mask must keep at least one dimension')
    # subtract the max logit for numerical stability before exponentiating
    c = np.max(logit)
    exp_logit = np.exp(logit - c) * mask
    # exp_logit is already masked, so a plain sum normalizes it
    # (clearer than the original `exp_logit @ mask`, identical for 0/1 masks)
    return exp_logit / np.sum(exp_logit)
+ + Attributes + ---------- + hg : nx.Graph + a bipartite graph representation of a hypergraph + edge_idx : int + total number of hyperedges that exist so far + ''' + def __init__(self): + self.hg = nx.Graph() + self.edge_idx = 0 + self.nodes = set([]) + self.num_nodes = 0 + self.edges = set([]) + self.num_edges = 0 + self.nodes_in_edge_dict = {} + + def add_node(self, node: str, attr_dict=None): + ''' add a node to hypergraph + + Parameters + ---------- + node : str + node name + attr_dict : dict + dictionary of node attributes + ''' + self.hg.add_node(node, bipartite='node', attr_dict=attr_dict) + if node not in self.nodes: + self.num_nodes += 1 + self.nodes.add(node) + + def add_edge(self, node_list: List[str], attr_dict=None, edge_name=None): + ''' add an edge consisting of nodes `node_list` + + Parameters + ---------- + node_list : list + ordered list of nodes that consist the edge + attr_dict : dict + dictionary of edge attributes + ''' + if edge_name is None: + edge = 'e{}'.format(self.edge_idx) + else: + assert edge_name not in self.edges + edge = edge_name + self.hg.add_node(edge, bipartite='edge', attr_dict=attr_dict) + if edge not in self.edges: + self.num_edges += 1 + self.edges.add(edge) + self.nodes_in_edge_dict[edge] = node_list + if type(node_list) == list: + for node_idx, each_node in enumerate(node_list): + self.hg.add_edge(edge, each_node, order=node_idx) + if each_node not in self.nodes: + self.num_nodes += 1 + self.nodes.add(each_node) + + elif type(node_list) == set: + for each_node in node_list: + self.hg.add_edge(edge, each_node, order=-1) + if each_node not in self.nodes: + self.num_nodes += 1 + self.nodes.add(each_node) + else: + raise ValueError + self.edge_idx += 1 + return edge + + def remove_node(self, node: str, remove_connected_edges=True): + ''' remove a node + + Parameters + ---------- + node : str + node name + remove_connected_edges : bool + if True, remove edges that are adjacent to the node + ''' + if remove_connected_edges: + 
    def remove_nodes(self, node_iter, remove_connected_edges=True):
        ''' remove a set of nodes

        Parameters
        ----------
        node_iter : iterator of strings
            nodes to be removed
        remove_connected_edges : bool
            if True, remove edges that are adjacent to the node
        '''
        # delegates to remove_node so the bookkeeping (num_nodes, edges)
        # stays consistent for every removal
        for each_node in node_iter:
            self.remove_node(each_node, remove_connected_edges)

    def remove_edge(self, edge: str):
        ''' remove an edge

        Parameters
        ----------
        edge : str
            edge to be removed
        '''
        # hyperedges are stored as 'edge'-side nodes of the bipartite graph,
        # so removing the bipartite node drops all incidences at once
        self.hg.remove_node(edge)
        self.edges.remove(edge)
        self.num_edges -= 1
        self.nodes_in_edge_dict.pop(edge)

    def remove_edges(self, edge_iter):
        ''' remove a set of edges

        Parameters
        ----------
        edge_iter : iterator of strings
            edges to be removed
        '''
        for each_edge in edge_iter:
            self.remove_edge(each_edge)

    def remove_edges_with_attr(self, edge_attr_dict):
        ''' remove every edge whose attributes contain all of the given
        key-value pairs.

        Parameters
        ----------
        edge_attr_dict : dict
            an edge is removed iff, for every key, its attribute value
            equals the given value (a missing key counts as a mismatch).
        '''
        remove_edge_list = []
        for each_edge in self.edges:
            satisfy = True
            for each_key, each_val in edge_attr_dict.items():
                if not satisfy:
                    break
                try:
                    if self.edge_attr(each_edge)[each_key] != each_val:
                        satisfy = False
                except KeyError:
                    # the edge lacks this attribute key -> no match
                    satisfy = False
            if satisfy:
                remove_edge_list.append(each_edge)
        # collected first, then removed, to avoid mutating while iterating
        self.remove_edges(remove_edge_list)

    def remove_subhg(self, subhg):
        ''' remove subhypergraph.
        all of the hyperedges are removed.
        each node of subhg is removed if its degree becomes 0 after removing hyperedges.

        Parameters
        ----------
        subhg : Hypergraph
        '''
        for each_edge in subhg.edges:
            self.remove_edge(each_edge)
        for each_node in subhg.nodes:
            if self.degree(each_node) == 0:
                self.remove_node(each_node)

    def nodes_in_edge(self, edge):
        ''' return an ordered list of nodes in a given edge.

        Parameters
        ----------
        edge : str
            edge whose nodes are returned

        Returns
        -------
        list or set
            ordered list or set of nodes that belong to the edge
        '''
        # auto-named edges ('e<idx>', see add_edge) have their node order cached
        if edge.startswith('e'):
            return self.nodes_in_edge_dict[edge]
        else:
            # otherwise recover the order from the 'order' labels on the
            # bipartite incidences; order == -1 for all means unordered (set)
            adj_node_list = self.hg.adj[edge]
            adj_node_order_list = []
            adj_node_name_list = []
            for each_node in adj_node_list:
                adj_node_order_list.append(adj_node_list[each_node]['order'])
                adj_node_name_list.append(each_node)
            if adj_node_order_list == [-1] * len(adj_node_order_list):
                return set(adj_node_name_list)
            else:
                return [adj_node_name_list[each_idx] for each_idx
                        in np.argsort(adj_node_order_list)]

    def adj_edges(self, node):
        ''' return a dict of adjacent hyperedges

        Parameters
        ----------
        node : str

        Returns
        -------
        set
            set of edges that are adjacent to `node`
        '''
        return self.hg.adj[node]

    def adj_nodes(self, node):
        ''' return a set of adjacent nodes

        Parameters
        ----------
        node : str

        Returns
        -------
        set
            set of nodes that are adjacent to `node`
        '''
        node_set = set([])
        # two nodes are adjacent iff they share at least one hyperedge
        for each_adj_edge in self.adj_edges(node):
            node_set.update(set(self.nodes_in_edge(each_adj_edge)))
        node_set.discard(node)
        return node_set

    def has_edge(self, node_list, ignore_order=False):
        # returns the name of an edge over exactly `node_list`
        # (order-insensitively if `ignore_order`), or False if none exists.
        # NOTE: linear scan over all edges.
        for each_edge in self.edges:
            if ignore_order:
                if set(self.nodes_in_edge(each_edge)) == set(node_list):
                    return each_edge
            else:
                if self.nodes_in_edge(each_edge) == node_list:
                    return each_edge
        return False

    def degree(self, node):
        # number of hyperedges incident to `node`
        return len(self.hg.adj[node])

    def degrees(self):
        # node -> degree, for every node
        return {each_node: self.degree(each_node) for each_node in self.nodes}

    def edge_degree(self, edge):
        # number of nodes the hyperedge spans
        return len(self.nodes_in_edge(edge))

    def edge_degrees(self):
        # edge -> edge_degree, for every edge
        return {each_edge: self.edge_degree(each_edge) for each_edge in self.edges}

    def is_adj(self, node1, node2):
        # True iff the two nodes share at least one hyperedge
        return node1 in self.adj_nodes(node2)
consisting of a set of nodes and hyperedges adjacent to `node`. + if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph. + + Parameters + ---------- + node : str + ident_node_dict : dict + dict containing identical nodes. see `get_identical_node_dict` for more details + + Returns + ------- + subhg : Hypergraph + """ + if ident_node_dict is None: + ident_node_dict = self.get_identical_node_dict() + adj_node_set = set(ident_node_dict[node]) + adj_edge_set = set([]) + for each_node in ident_node_dict[node]: + adj_edge_set.update(set(self.adj_edges(each_node))) + fixed_adj_edge_set = deepcopy(adj_edge_set) + for each_edge in fixed_adj_edge_set: + other_nodes = self.nodes_in_edge(each_edge) + adj_node_set.update(other_nodes) + + # if the adjacent node has self-loop edge, it will be appended to adj_edge_list. + for each_node in other_nodes: + for other_edge in set(self.adj_edges(each_node)) - set([each_edge]): + if len(set(self.nodes_in_edge(other_edge)) \ + - set(self.nodes_in_edge(each_edge))) == 0: + adj_edge_set.update(set([other_edge])) + subhg = Hypergraph() + for each_node in adj_node_set: + subhg.add_node(each_node, attr_dict=self.node_attr(each_node)) + for each_edge in adj_edge_set: + subhg.add_edge(self.nodes_in_edge(each_edge), + attr_dict=self.edge_attr(each_edge), + edge_name=each_edge) + subhg.edge_idx = self.edge_idx + return subhg + + def get_subhg(self, node_list, edge_list, ident_node_dict=None): + """ return a subhypergraph consisting of a set of nodes and hyperedges adjacent to `node`. + if an adjacent node has a self-loop hyperedge, it will be also added to the subhypergraph. + + Parameters + ---------- + node : str + ident_node_dict : dict + dict containing identical nodes. 
see `get_identical_node_dict` for more details + + Returns + ------- + subhg : Hypergraph + """ + if ident_node_dict is None: + ident_node_dict = self.get_identical_node_dict() + adj_node_set = set([]) + for each_node in node_list: + adj_node_set.update(set(ident_node_dict[each_node])) + adj_edge_set = set(edge_list) + + subhg = Hypergraph() + for each_node in adj_node_set: + subhg.add_node(each_node, + attr_dict=deepcopy(self.node_attr(each_node))) + for each_edge in adj_edge_set: + subhg.add_edge(self.nodes_in_edge(each_edge), + attr_dict=deepcopy(self.edge_attr(each_edge)), + edge_name=each_edge) + subhg.edge_idx = self.edge_idx + return subhg + + def copy(self): + ''' return a copy of the object + + Returns + ------- + Hypergraph + ''' + return deepcopy(self) + + def node_attr(self, node): + return self.hg.nodes[node]['attr_dict'] + + def edge_attr(self, edge): + return self.hg.nodes[edge]['attr_dict'] + + def set_node_attr(self, node, attr_dict): + for each_key, each_val in attr_dict.items(): + self.hg.nodes[node]['attr_dict'][each_key] = each_val + + def set_edge_attr(self, edge, attr_dict): + for each_key, each_val in attr_dict.items(): + self.hg.nodes[edge]['attr_dict'][each_key] = each_val + + def get_identical_node_dict(self): + ''' get identical nodes + nodes are identical if they share the same set of adjacent edges. + + Returns + ------- + ident_node_dict : dict + ident_node_dict[node] returns a list of nodes that are identical to `node`. 
+ ''' + ident_node_dict = {} + for each_node in self.nodes: + ident_node_list = [] + for each_other_node in self.nodes: + if each_other_node == each_node: + ident_node_list.append(each_other_node) + elif self.adj_edges(each_node) == self.adj_edges(each_other_node) \ + and len(self.adj_edges(each_node)) != 0: + ident_node_list.append(each_other_node) + ident_node_dict[each_node] = ident_node_list + return ident_node_dict + ''' + ident_node_dict = {} + for each_node in self.nodes: + ident_node_dict[each_node] = [each_node] + return ident_node_dict + ''' + + def get_leaf_edge(self): + ''' get an edge that is incident only to one edge + + Returns + ------- + if exists, return a leaf edge. otherwise, return None. + ''' + for each_edge in self.edges: + if len(self.adj_nodes(each_edge)) == 1: + if 'tmp' not in self.edge_attr(each_edge): + return each_edge + return None + + def get_nontmp_edge(self): + for each_edge in self.edges: + if 'tmp' not in self.edge_attr(each_edge): + return each_edge + return None + + def is_subhg(self, hg): + ''' return whether this hypergraph is a subhypergraph of `hg` + + Returns + ------- + True if self \in hg, + False otherwise. + ''' + for each_node in self.nodes: + if each_node not in hg.nodes: + return False + for each_edge in self.edges: + if each_edge not in hg.edges: + return False + return True + + def in_cycle(self, node, visited=None, parent='', root_node='') -> bool: + ''' if `node` is in a cycle, then return True. otherwise, False. + + Parameters + ---------- + node : str + node in a hypergraph + visited : list + list of visited nodes, used for recursion + parent : str + parent node, used to eliminate a cycle consisting of two nodes and one edge. 
    def draw(self, file_path=None, with_node=False, with_edge_name=False):
        ''' draw this hypergraph using graphviz.

        External nodes (those carrying an ``ext_id`` node attribute) are drawn
        as small black circles; other nodes are drawn (gray) only when
        `with_node` is True.  Hyperedges are drawn as square nodes, filled for
        non-terminals, and connected to each other by as many graphviz edges
        as the bond order between them.

        Parameters
        ----------
        file_path : str or None
            if not None, the graph is rendered to ``<file_path>.png``
            via ``G.render(file_path, cleanup=True)``.
        with_node : bool
            if True, draw every node and link it to its hyperedges.
        with_edge_name : bool
            if True, append the internal edge identifier to each label.

        Returns
        -------
        graphviz.Graph
        '''
        # local import: drawing is optional and graphviz may not be installed
        import graphviz
        G = graphviz.Graph(format='png')
        for each_node in self.nodes:
            if 'ext_id' in self.node_attr(each_node):
                # external node: always shown, black
                G.node(each_node, label='',
                       shape='circle', width='0.1', height='0.1', style='filled',
                       fillcolor='black')
            else:
                if with_node:
                    G.node(each_node, label='',
                           shape='circle', width='0.1', height='0.1', style='filled',
                           fillcolor='gray')
        # edge_list remembers {node, edge} / {edge, edge} pairs already drawn
        # so the same connection is not emitted twice
        edge_list = []
        for each_edge in self.edges:
            if self.edge_attr(each_edge).get('terminal', False):
                # terminal hyperedge: labeled with its symbol, unfilled square
                G.node(each_edge,
                       label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
                       else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
                       fontcolor='black', shape='square')
            elif self.edge_attr(each_edge).get('tmp', False):
                # temporary hyperedge introduced during grammar operations
                G.node(each_edge, label='tmp' if not with_edge_name else 'tmp, ' + each_edge,
                       fontcolor='black', shape='square')
            else:
                # non-terminal hyperedge: filled square
                G.node(each_edge,
                       label=self.edge_attr(each_edge)['symbol'].symbol if not with_edge_name \
                       else self.edge_attr(each_edge)['symbol'].symbol + ', ' + each_edge,
                       fontcolor='black', shape='square', style='filled')
            if with_node:
                for each_node in self.nodes_in_edge(each_edge):
                    G.edge(each_edge, each_node)
            else:
                for each_node in self.nodes_in_edge(each_edge):
                    if 'ext_id' in self.node_attr(each_node)\
                       and set([each_node, each_edge]) not in edge_list:
                        G.edge(each_edge, each_node)
                        edge_list.append(set([each_node, each_edge]))
                for each_other_edge in self.adj_nodes(each_edge):
                    if set([each_edge, each_other_edge]) not in edge_list:
                        # bond order between two hyperedges = sum over shared nodes:
                        # bond_type 1/2/3 counted as-is, aromatic (12) counted as 1
                        num_bond = 0
                        common_node_set = set(self.nodes_in_edge(each_edge))\
                            .intersection(set(self.nodes_in_edge(each_other_edge)))
                        for each_node in common_node_set:
                            if self.node_attr(each_node)['symbol'].bond_type in [1, 2, 3]:
                                num_bond += self.node_attr(each_node)['symbol'].bond_type
                            elif self.node_attr(each_node)['symbol'].bond_type in [12]:
                                num_bond += 1
                            else:
                                raise NotImplementedError('unsupported bond type')
                        # draw one parallel line per bond
                        for _ in range(num_bond):
                            G.edge(each_edge, each_other_edge)
                        edge_list.append(set([each_edge, each_other_edge]))
        if file_path is not None:
            G.render(file_path, cleanup=True)
            #os.remove(file_path)
        return G

    def is_dividable(self, node):
        ''' return True iff removing `node` disconnects the underlying graph,
        i.e. `node` is a cut vertex.
        '''
        _hg = deepcopy(self.hg)
        _hg.remove_node(node)
        return (not nx.is_connected(_hg))

    def divide(self, node):
        ''' split this hypergraph at cut-vertex `node`.

        Parameters
        ----------
        node : str
            node at which to divide (should satisfy `is_dividable(node)`).

        Returns
        -------
        list of subhypergraphs; `node` itself is kept in every component.
        '''
        subhg_list = []

        hg_wo_node = deepcopy(self)
        hg_wo_node.remove_node(node, remove_connected_edges=False)
        connected_components = nx.connected_components(hg_wo_node.hg)
        for each_component in connected_components:
            node_list = [node]
            edge_list = []
            # components mix node ids ('bond_*') and hyperedge ids ('e*');
            # separate them by their naming prefix
            node_list.extend([each_node for each_node in each_component
                              if each_node.startswith('bond_')])
            edge_list.extend([each_edge for each_edge in each_component
                              if each_edge.startswith('e')])
            subhg_list.append(self.get_subhg(node_list, edge_list))
            #subhg_list[-1].set_node_attr(node, {'divided': True})
        return subhg_list
+THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 2018" +__version__ = "0.1" +__date__ = "Jan 1 2018" + diff --git a/graph_grammar/io/__pycache__/__init__.cpython-310.pyc b/graph_grammar/io/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f6e140d8246ba4cb6b25bcf7d89f2511581804 Binary files /dev/null and b/graph_grammar/io/__pycache__/__init__.cpython-310.pyc differ diff --git a/graph_grammar/io/__pycache__/smi.cpython-310.pyc b/graph_grammar/io/__pycache__/smi.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..682a16d3069fb140568cc08d7251f4bef4240126 Binary files /dev/null and b/graph_grammar/io/__pycache__/smi.cpython-310.pyc differ diff --git a/graph_grammar/io/smi.py b/graph_grammar/io/smi.py new file mode 100644 index 0000000000000000000000000000000000000000..dd17428fdcb7888365cffc59867976f415841d79 --- /dev/null +++ b/graph_grammar/io/smi.py @@ -0,0 +1,559 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 
class HGGen(object):
    """
    load .smi file and yield a hypergraph.

    Attributes
    ----------
    path_to_file : str
        path to .smi file
    kekulize : bool
        kekulize or not
    add_Hs : bool
        add implicit hydrogens to the molecule or not.
    all_single : bool
        if True, all multiple bonds are summarized into a single bond with some attributes
        # NOTE(review): stored but not referenced in this class's visible code — verify downstream use.

    Yields
    ------
    Hypergraph
    """
    def __init__(self, path_to_file, kekulize=True, add_Hs=False, all_single=True):
        # num_line tracks the 1-based line of the *next* molecule, used only
        # for the error message on a failed parse
        self.num_line = 1
        self.mol_gen = Chem.SmilesMolSupplier(path_to_file, titleLine=False)
        self.kekulize = kekulize
        self.add_Hs = add_Hs
        self.all_single = all_single

    def __iter__(self):
        return self

    def __next__(self):
        '''
        each_mol = None
        while each_mol is None:
            each_mol = next(self.mol_gen)
        '''
        # not ignoring parse errors
        each_mol = next(self.mol_gen)
        if each_mol is None:
            # RDKit returns None for an unparsable SMILES line
            raise ValueError(f'incorrect smiles in line {self.num_line}')
        else:
            self.num_line += 1
        return mol_to_hg(each_mol, self.kekulize, self.add_Hs)
def mol_to_bipartite(mol, kekulize):
    """
    get a bipartite representation of a molecule.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        molecule object
    kekulize : bool
        kekulize or not

    Returns
    -------
    nx.Graph
        a bipartite graph representing which bond is connected to which atoms.
        nodes are named 'atom_<idx>' and 'bond_<idx>'.
    """
    try:
        # normalize E/Z bookkeeping before reading bond attributes
        mol = standardize_stereo(mol)
    except KeyError:
        # print the offending molecule before propagating for debugging
        print(Chem.MolToSmiles(mol))
        raise KeyError

    if kekulize:
        Chem.Kekulize(mol)

    bipartite_g = nx.Graph()
    for each_atom in mol.GetAtoms():
        bipartite_g.add_node(f"atom_{each_atom.GetIdx()}",
                             atom_attr=atom_attr(each_atom, kekulize))

    for each_bond in mol.GetBonds():
        bond_idx = each_bond.GetIdx()
        bipartite_g.add_node(
            f"bond_{bond_idx}",
            bond_attr=bond_attr(each_bond, kekulize))
        # each bond node is linked to exactly its two endpoint atom nodes
        bipartite_g.add_edge(
            f"atom_{each_bond.GetBeginAtomIdx()}",
            f"bond_{bond_idx}")
        bipartite_g.add_edge(
            f"atom_{each_bond.GetEndAtomIdx()}",
            f"bond_{bond_idx}")
    return bipartite_g


def mol_to_hg(mol, kekulize, add_Hs):
    """
    convert a molecule into a hypergraph: atoms become hyperedges,
    bonds become nodes.

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        molecule object
    kekulize : bool
        kekulize or not
    add_Hs : bool
        add implicit hydrogens to the molecule or not.

    Returns
    -------
    Hypergraph
    """
    if add_Hs:
        mol = Chem.AddHs(mol)

    if kekulize:
        Chem.Kekulize(mol)

    bipartite_g = mol_to_bipartite(mol, kekulize)
    hg = Hypergraph()
    for each_atom in [each_node for each_node in bipartite_g.nodes()
                      if each_node.startswith('atom_')]:
        # the hyperedge for an atom spans all of its incident bond-nodes
        node_set = set([])
        for each_bond in bipartite_g.adj[each_atom]:
            hg.add_node(each_bond,
                        attr_dict=bipartite_g.nodes[each_bond]['bond_attr'])
            node_set.add(each_bond)
        hg.add_edge(node_set,
                    attr_dict=bipartite_g.nodes[each_atom]['atom_attr'])
    return hg
def hg_to_mol(hg, verbose=False):
    """ convert a hypergraph into Mol object

    Hyperedges are turned into atoms, nodes into bonds; stereo information is
    restored afterwards via `set_stereo`, falling back to the non-stereo
    molecule when that fails.

    Parameters
    ----------
    hg : Hypergraph
    verbose : bool
        if True, also return whether stereo assignment succeeded.

    Returns
    -------
    mol : Chem.Mol
        (mol, is_stereo) when `verbose` is True.

    Raises
    ------
    RuntimeError
        if no valid molecule can be reconstructed at all.
    """
    mol = Chem.RWMol()
    atom_dict = {}
    # bond_set stores both concatenation orders of the two edge ids so each
    # node (bond) is materialized only once
    bond_set = set([])
    for each_edge in hg.edges:
        atom = Chem.Atom(hg.edge_attr(each_edge)['symbol'].symbol)
        atom.SetNumExplicitHs(hg.edge_attr(each_edge)['symbol'].num_explicit_Hs)
        atom.SetFormalCharge(hg.edge_attr(each_edge)['symbol'].formal_charge)
        atom.SetChiralTag(
            Chem.rdchem.ChiralType.values[
                hg.edge_attr(each_edge)['symbol'].chirality])
        atom_idx = mol.AddAtom(atom)
        atom_dict[each_edge] = atom_idx

    for each_node in hg.nodes:
        edge_1, edge_2 = hg.adj_edges(each_node)
        if edge_1+edge_2 not in bond_set:
            if hg.node_attr(each_node)['symbol'].bond_type <= 3:
                num_bond = hg.node_attr(each_node)['symbol'].bond_type
            elif hg.node_attr(each_node)['symbol'].bond_type == 12:
                # aromatic bonds are encoded as single here; aromaticity is
                # re-perceived below
                num_bond = 1
            else:
                # bug fix: the attr key is 'symbol' (not 'bond_symbol');
                # the old f-string raised KeyError instead of this ValueError
                raise ValueError(f'too many bonds; {hg.node_attr(each_node)["symbol"].bond_type}')
            _ = mol.AddBond(atom_dict[edge_1],
                            atom_dict[edge_2],
                            order=Chem.rdchem.BondType.values[num_bond])
            bond_idx = mol.GetBondBetweenAtoms(atom_dict[edge_1], atom_dict[edge_2]).GetIdx()

            # stereo
            mol.GetBondWithIdx(bond_idx).SetStereo(
                Chem.rdchem.BondStereo.values[hg.node_attr(each_node)['symbol'].stereo])
            bond_set.update([edge_1+edge_2])
            bond_set.update([edge_2+edge_1])
    mol.UpdatePropertyCache()
    mol = mol.GetMol()
    not_stereo_mol = deepcopy(mol)
    if Chem.MolFromSmiles(Chem.MolToSmiles(not_stereo_mol)) is None:
        raise RuntimeError('no valid molecule was obtained.')
    try:
        mol = set_stereo(mol)
        is_stereo = True
    except Exception:
        # best-effort: report the failure and continue without stereo
        import traceback
        traceback.print_exc()
        is_stereo = False
    # prefer the aromatic form if it round-trips through SMILES
    mol_tmp = deepcopy(mol)
    Chem.SetAromaticity(mol_tmp)
    if Chem.MolFromSmiles(Chem.MolToSmiles(mol_tmp)) is not None:
        mol = mol_tmp
    else:
        if Chem.MolFromSmiles(Chem.MolToSmiles(mol)) is None:
            mol = not_stereo_mol
    mol.UpdatePropertyCache()
    Chem.GetSymmSSSR(mol)
    # canonicalize by a final SMILES round-trip
    mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    if verbose:
        return mol, is_stereo
    else:
        return mol

def hgs_to_mols(hg_list, ignore_error=False):
    """ convert a list of hypergraphs into Mol objects.

    Parameters
    ----------
    hg_list : list of Hypergraph
    ignore_error : bool
        if True, a failed conversion yields None instead of raising.

    Returns
    -------
    list of (Chem.Mol or None)
    """
    if ignore_error:
        mol_list = []
        for each_hg in hg_list:
            try:
                mol = hg_to_mol(each_hg)
            except Exception:
                # deliberate best-effort; narrowed from a bare `except:`
                mol = None
            mol_list.append(mol)
    else:
        mol_list = [hg_to_mol(each_hg) for each_hg in hg_list]
    return mol_list
def hgs_to_smiles(hg_list, ignore_error=False):
    """ convert a list of hypergraphs into canonical SMILES strings.

    Parameters
    ----------
    hg_list : list of Hypergraph
    ignore_error : bool
        passed to `hgs_to_mols`; failed conversions end up as None entries.

    Returns
    -------
    list of (str or None)
    """
    mol_list = hgs_to_mols(hg_list, ignore_error)
    smiles_list = []
    for each_mol in mol_list:
        try:
            # SMILES -> Mol -> SMILES round-trip canonicalizes the string
            smiles_list.append(
                Chem.MolToSmiles(
                    Chem.MolFromSmiles(
                        Chem.MolToSmiles(
                            each_mol))))
        except Exception:
            # best-effort (e.g. each_mol is None); narrowed from bare `except:`
            smiles_list.append(None)
    return smiles_list

def atom_attr(atom, kekulize):
    """
    get atom's attributes

    Parameters
    ----------
    atom : rdkit.Chem.rdchem.Atom
    kekulize : bool
        kekulize or not

    Returns
    -------
    atom_attr : dict
        'terminal' : bool (always True for atoms)
        'is_in_ring' : bool
        'symbol' : TSymbol
            aromaticity is forced False under kekulization.
    """
    if kekulize:
        return {'terminal': True,
                'is_in_ring': atom.IsInRing(),
                'symbol': TSymbol(degree=0,
                                  #degree=atom.GetTotalDegree(),
                                  is_aromatic=False,
                                  symbol=atom.GetSymbol(),
                                  num_explicit_Hs=atom.GetNumExplicitHs(),
                                  formal_charge=atom.GetFormalCharge(),
                                  chirality=atom.GetChiralTag().real
                                  )}
    else:
        return {'terminal': True,
                'is_in_ring': atom.IsInRing(),
                'symbol': TSymbol(degree=0,
                                  #degree=atom.GetTotalDegree(),
                                  is_aromatic=atom.GetIsAromatic(),
                                  symbol=atom.GetSymbol(),
                                  num_explicit_Hs=atom.GetNumExplicitHs(),
                                  formal_charge=atom.GetFormalCharge(),
                                  chirality=atom.GetChiralTag().real
                                  )}
def bond_attr(bond, kekulize):
    """
    get bond's attributes

    Parameters
    ----------
    bond : rdkit.Chem.rdchem.Bond
    kekulize : bool
        kekulize or not; if True, aromatic bonds (type 12) are mapped to
        single bonds (type 1) and aromaticity is recorded as False.

    Returns
    -------
    bond_attr : dict
        "bond_type" : int
         {0: rdkit.Chem.rdchem.BondType.UNSPECIFIED,
          1: rdkit.Chem.rdchem.BondType.SINGLE,
          2: rdkit.Chem.rdchem.BondType.DOUBLE,
          3: rdkit.Chem.rdchem.BondType.TRIPLE,
          4: rdkit.Chem.rdchem.BondType.QUADRUPLE,
          5: rdkit.Chem.rdchem.BondType.QUINTUPLE,
          6: rdkit.Chem.rdchem.BondType.HEXTUPLE,
          7: rdkit.Chem.rdchem.BondType.ONEANDAHALF,
          8: rdkit.Chem.rdchem.BondType.TWOANDAHALF,
          9: rdkit.Chem.rdchem.BondType.THREEANDAHALF,
          10: rdkit.Chem.rdchem.BondType.FOURANDAHALF,
          11: rdkit.Chem.rdchem.BondType.FIVEANDAHALF,
          12: rdkit.Chem.rdchem.BondType.AROMATIC,
          13: rdkit.Chem.rdchem.BondType.IONIC,
          14: rdkit.Chem.rdchem.BondType.HYDROGEN,
          15: rdkit.Chem.rdchem.BondType.THREECENTER,
          16: rdkit.Chem.rdchem.BondType.DATIVEONE,
          17: rdkit.Chem.rdchem.BondType.DATIVE,
          18: rdkit.Chem.rdchem.BondType.DATIVEL,
          19: rdkit.Chem.rdchem.BondType.DATIVER,
          20: rdkit.Chem.rdchem.BondType.OTHER,
          21: rdkit.Chem.rdchem.BondType.ZERO}
    """
    if kekulize:
        is_aromatic = False
        # after kekulization an aromatic bond should not remain, but guard anyway
        if bond.GetBondType().real == 12:
            bond_type = 1
        else:
            bond_type = bond.GetBondType().real
    else:
        is_aromatic = bond.GetIsAromatic()
        bond_type = bond.GetBondType().real
    return {'symbol': BondSymbol(is_aromatic=is_aromatic,
                                 bond_type=bond_type,
                                 stereo=int(bond.GetStereo())),
            'is_in_ring': bond.IsInRing()}
def standardize_stereo(mol):
    '''
    Canonicalize double-bond stereo annotations so that, for every E/Z bond,
    the recorded stereo atoms are the lower-CIP-rank neighbors on both ends
    (flipping STEREOZ <-> STEREOE as needed to keep the geometry unchanged).

    BondDir codes used below:
    0: rdkit.Chem.rdchem.BondDir.NONE,
    1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
    2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
    3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
    4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,

    '''
    # mol = Chem.AddHs(mol) # this removes CIPRank !!!
    for each_bond in mol.GetBonds():
        if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
            begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
            end_stereo_atom_idx = each_bond.GetEndAtomIdx()
            atom_idx_1 = each_bond.GetStereoAtoms()[0]
            atom_idx_2 = each_bond.GetStereoAtoms()[1]
            # identify which stereo atom hangs off the begin end of the bond
            if mol.GetBondBetweenAtoms(atom_idx_1, begin_stereo_atom_idx):
                begin_atom_idx = atom_idx_1
                end_atom_idx = atom_idx_2
            else:
                begin_atom_idx = atom_idx_2
                end_atom_idx = atom_idx_1

            # the "other" substituent on the begin side, if any
            begin_another_atom_idx = None
            assert len(mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()) <= 3
            for each_neighbor in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors():
                each_neighbor_idx = each_neighbor.GetIdx()
                if each_neighbor_idx not in [end_stereo_atom_idx, begin_atom_idx]:
                    begin_another_atom_idx = each_neighbor_idx

            # the "other" substituent on the end side, if any
            end_another_atom_idx = None
            assert len(mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()) <= 3
            for each_neighbor in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors():
                each_neighbor_idx = each_neighbor.GetIdx()
                if each_neighbor_idx not in [begin_stereo_atom_idx, end_atom_idx]:
                    end_another_atom_idx = each_neighbor_idx

            '''
            relationship between begin_atom_idx and end_atom_idx is encoded in GetStereo
            '''
            begin_atom_rank = int(mol.GetAtomWithIdx(begin_atom_idx).GetProp('_CIPRank'))
            end_atom_rank = int(mol.GetAtomWithIdx(end_atom_idx).GetProp('_CIPRank'))
            # a missing "other" neighbor gets infinite rank so the existing
            # stereo atom always wins the comparison
            try:
                begin_another_atom_rank = int(mol.GetAtomWithIdx(begin_another_atom_idx).GetProp('_CIPRank'))
            except:
                begin_another_atom_rank = np.inf
            try:
                end_another_atom_rank = int(mol.GetAtomWithIdx(end_another_atom_idx).GetProp('_CIPRank'))
            except:
                end_another_atom_rank = np.inf
            if begin_atom_rank < begin_another_atom_rank\
               and end_atom_rank < end_another_atom_rank:
                # already canonical: both stereo atoms are the low-rank neighbors
                pass
            elif begin_atom_rank < begin_another_atom_rank\
                 and end_atom_rank > end_another_atom_rank:
                # (begin_atom_idx +) end_another_atom_idx should be in StereoAtoms
                if each_bond.GetStereo() == 2:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
                elif each_bond.GetStereo() == 3:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
                else:
                    raise ValueError
                each_bond.SetStereoAtoms(begin_atom_idx, end_another_atom_idx)
            elif begin_atom_rank > begin_another_atom_rank\
                 and end_atom_rank < end_another_atom_rank:
                # (end_atom_idx +) begin_another_atom_idx should be in StereoAtoms
                if each_bond.GetStereo() == 2:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[3])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
                elif each_bond.GetStereo() == 3:
                    # set stereo
                    each_bond.SetStereo(Chem.rdchem.BondStereo.values[2])
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 0)
                else:
                    raise ValueError
                each_bond.SetStereoAtoms(begin_another_atom_idx, end_atom_idx)
            elif begin_atom_rank > begin_another_atom_rank\
                 and end_atom_rank > end_another_atom_rank:
                # begin_another_atom_idx + end_another_atom_idx should be in StereoAtoms
                if each_bond.GetStereo() == 2:
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 3)
                elif each_bond.GetStereo() == 3:
                    # set bond dir
                    mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, begin_another_atom_idx, begin_stereo_atom_idx, 4)
                    mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 0)
                    mol = safe_set_bond_dir(mol, end_another_atom_idx, end_stereo_atom_idx, 4)
                else:
                    raise ValueError
                each_bond.SetStereoAtoms(begin_another_atom_idx, end_another_atom_idx)
            else:
                # equal ranks should not occur for a genuine stereo bond
                raise RuntimeError
    return mol
def set_stereo(mol):
    '''
    Re-apply double-bond stereo to a SMILES round-tripped copy of `mol`
    and translate the stereo flags into explicit bond directions.

    BondDir codes used below:
    0: rdkit.Chem.rdchem.BondDir.NONE,
    1: rdkit.Chem.rdchem.BondDir.BEGINWEDGE,
    2: rdkit.Chem.rdchem.BondDir.BEGINDASH,
    3: rdkit.Chem.rdchem.BondDir.ENDDOWNRIGHT,
    4: rdkit.Chem.rdchem.BondDir.ENDUPRIGHT,
    '''
    # round-trip through SMILES to get a canonical atom ordering
    _mol = Chem.MolFromSmiles(Chem.MolToSmiles(mol))
    Chem.Kekulize(_mol, True)
    substruct_match = mol.GetSubstructMatch(_mol)
    if not substruct_match:
        ''' mol and _mol are kekulized.
        sometimes, the order of '=' and '-' changes, which causes mol and _mol not matched.
        '''
        Chem.SetAromaticity(mol)
        Chem.SetAromaticity(_mol)
        substruct_match = mol.GetSubstructMatch(_mol)
    try:
        atom_match = {substruct_match[_mol_atom_idx]: _mol_atom_idx for _mol_atom_idx in range(_mol.GetNumAtoms())} # mol to _mol
    except:
        raise ValueError('two molecules obtained from the same data do not match.')

    # copy the stereo flags over to the canonical copy, bond by bond
    for each_bond in mol.GetBonds():
        begin_atom_idx = each_bond.GetBeginAtomIdx()
        end_atom_idx = each_bond.GetEndAtomIdx()
        _bond = _mol.GetBondBetweenAtoms(atom_match[begin_atom_idx], atom_match[end_atom_idx])
        _bond.SetStereo(each_bond.GetStereo())

    mol = _mol
    for each_bond in mol.GetBonds():
        if int(each_bond.GetStereo()) in [2, 3]: #2=Z (same side), 3=E
            begin_stereo_atom_idx = each_bond.GetBeginAtomIdx()
            end_stereo_atom_idx = each_bond.GetEndAtomIdx()
            # substituents on each end of the double bond (excluding the bond itself)
            begin_atom_idx_set = set([each_neighbor.GetIdx()
                                      for each_neighbor
                                      in mol.GetAtomWithIdx(begin_stereo_atom_idx).GetNeighbors()
                                      if each_neighbor.GetIdx() != end_stereo_atom_idx])
            end_atom_idx_set = set([each_neighbor.GetIdx()
                                    for each_neighbor
                                    in mol.GetAtomWithIdx(end_stereo_atom_idx).GetNeighbors()
                                    if each_neighbor.GetIdx() != begin_stereo_atom_idx])
            # a terminal double bond cannot carry stereo: clear it
            if not begin_atom_idx_set:
                each_bond.SetStereo(Chem.rdchem.BondStereo(0))
                continue
            if not end_atom_idx_set:
                each_bond.SetStereo(Chem.rdchem.BondStereo(0))
                continue
            # with a single substituent it is the stereo atom by default;
            # with two, the lower CIP rank wins (matching standardize_stereo)
            if len(begin_atom_idx_set) == 1:
                begin_atom_idx = begin_atom_idx_set.pop()
                begin_another_atom_idx = None
            if len(end_atom_idx_set) == 1:
                end_atom_idx = end_atom_idx_set.pop()
                end_another_atom_idx = None
            if len(begin_atom_idx_set) == 2:
                atom_idx_1 = begin_atom_idx_set.pop()
                atom_idx_2 = begin_atom_idx_set.pop()
                if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
                    begin_atom_idx = atom_idx_1
                    begin_another_atom_idx = atom_idx_2
                else:
                    begin_atom_idx = atom_idx_2
                    begin_another_atom_idx = atom_idx_1
            if len(end_atom_idx_set) == 2:
                atom_idx_1 = end_atom_idx_set.pop()
                atom_idx_2 = end_atom_idx_set.pop()
                if int(mol.GetAtomWithIdx(atom_idx_1).GetProp('_CIPRank')) < int(mol.GetAtomWithIdx(atom_idx_2).GetProp('_CIPRank')):
                    end_atom_idx = atom_idx_1
                    end_another_atom_idx = atom_idx_2
                else:
                    end_atom_idx = atom_idx_2
                    end_another_atom_idx = atom_idx_1

            if each_bond.GetStereo() == 2: # same side
                mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 4)
                each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
            elif each_bond.GetStereo() == 3: # opposite side
                mol = safe_set_bond_dir(mol, begin_atom_idx, begin_stereo_atom_idx, 3)
                mol = safe_set_bond_dir(mol, end_atom_idx, end_stereo_atom_idx, 3)
                each_bond.SetStereoAtoms(begin_atom_idx, end_atom_idx)
            else:
                raise ValueError
    return mol


def safe_set_bond_dir(mol, atom_idx_1, atom_idx_2, bond_dir_val):
    ''' set the direction of the bond between two atoms; a no-op when either
    endpoint index is None (missing substituent).

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
    atom_idx_1, atom_idx_2 : int or None
    bond_dir_val : int
        index into Chem.rdchem.BondDir.values

    Returns
    -------
    mol : the same molecule, for chaining
    '''
    if atom_idx_1 is None or atom_idx_2 is None:
        return mol
    else:
        mol.GetBondBetweenAtoms(atom_idx_1, atom_idx_2).SetBondDir(Chem.rdchem.BondDir.values[bond_dir_val])
        return mol
+""" diff --git a/graph_grammar/nn/__pycache__/__init__.cpython-310.pyc b/graph_grammar/nn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a19d9c8a3800a7939024885b5a30651bd8bb1695 Binary files /dev/null and b/graph_grammar/nn/__pycache__/__init__.cpython-310.pyc differ diff --git a/graph_grammar/nn/__pycache__/decoder.cpython-310.pyc b/graph_grammar/nn/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08e683bfd8bb50a3efb57086958589af1bc760b1 Binary files /dev/null and b/graph_grammar/nn/__pycache__/decoder.cpython-310.pyc differ diff --git a/graph_grammar/nn/__pycache__/encoder.cpython-310.pyc b/graph_grammar/nn/__pycache__/encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae5d350bbfea93dd453eacbf3ab4afa188f2e61e Binary files /dev/null and b/graph_grammar/nn/__pycache__/encoder.cpython-310.pyc differ diff --git a/graph_grammar/nn/dataset.py b/graph_grammar/nn/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..70894b289d6413dd29c81dfd3934574c6edb1383 --- /dev/null +++ b/graph_grammar/nn/dataset.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +""" +PLEASE NOTE THIS IMPLEMENTATION INCLUDES THE ORIGINAL SOURCE CODE (AND SOME ADAPTATIONS) +OF THE MHG IMPLEMENTATION OF HIROSHI KAJINO AT IBM TRL ALREADY PUBLICLY AVAILABLE. +THIS MIGHT INFLUENCE THE DECISION OF THE FINAL LICENSE SO CAREFUL CHECK NEEDS BE DONE. +""" + +""" Title """ + +__author__ = "Hiroshi Kajino " +__copyright__ = "(c) Copyright IBM Corp. 
def left_padding(sentence_list, max_len, pad_idx=-1, inverse=False):
    ''' pad left

    Parameters
    ----------
    sentence_list : list of sequences of integers
    max_len : int
        maximum length of sentences.
        if a sentence is shorter than `max_len`, its left part is padded.
    pad_idx : int
        integer for padding
    inverse : bool
        if True, the sequence is inversed.

    Returns
    -------
    List of torch.LongTensor
        each sentence is left-padded.

    Raises
    ------
    ValueError
        if any sentence is longer than `max_len`.
    '''
    longest = max([len(each_seq) for each_seq in sentence_list])
    if longest > max_len:
        raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(longest))

    padded = []
    for each_seq in sentence_list:
        prefix = [pad_idx] * (max_len - len(each_seq))
        body = each_seq[::-1] if inverse else each_seq
        padded.append(torch.LongTensor(prefix + body))
    return padded


def right_padding(sentence_list, max_len, pad_idx=-1):
    ''' pad right

    Parameters
    ----------
    sentence_list : list of sequences of integers
    max_len : int
        maximum length of sentences.
        if a sentence is shorter than `max_len`, its right part is padded.
    pad_idx : int
        integer for padding

    Returns
    -------
    List of torch.LongTensor
        each sentence is right-padded.

    Raises
    ------
    ValueError
        if any sentence is longer than `max_len`.
    '''
    longest = max([len(each_seq) for each_seq in sentence_list])
    if longest > max_len:
        raise ValueError('`max_len` should be larger than the maximum length of input sequences, {}.'.format(longest))

    return [torch.LongTensor(each_seq + [pad_idx] * (max_len - len(each_seq)))
            for each_seq in sentence_list]
class HRGDataset(Dataset):

    '''
    A class of HRG data

    Each item is a production-rule sequence padded twice: left-padded
    (optionally inversed) and right-padded, plus an optional target value.
    '''

    def __init__(self, hrg, prod_rule_seq_list, max_len, target_val_list=None, inversed_input=False):
        self.hrg = hrg
        self.left_prod_rule_seq_list = left_padding(prod_rule_seq_list,
                                                    max_len,
                                                    inverse=inversed_input)

        self.right_prod_rule_seq_list = right_padding(prod_rule_seq_list, max_len)
        # fixed attribute name; the original misspelling is kept as an alias
        # so existing callers reading `inserved_input` still work
        self.inversed_input = inversed_input
        self.inserved_input = inversed_input
        self.target_val_list = target_val_list
        if target_val_list is not None:
            if len(prod_rule_seq_list) != len(target_val_list):
                raise ValueError(f'prod_rule_seq_list and target_val_list have inconsistent lengths: {len(prod_rule_seq_list)}, {len(target_val_list)}')

    def __len__(self):
        return len(self.left_prod_rule_seq_list)

    def __getitem__(self, idx):
        # returns (left_padded, right_padded[, target]) for sample `idx`
        if self.target_val_list is not None:
            return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx], np.float32(self.target_val_list[idx])
        else:
            return self.left_prod_rule_seq_list[idx], self.right_prod_rule_seq_list[idx]

    @property
    def vocab_size(self):
        # vocabulary = number of production rules in the grammar
        return self.hrg.num_prod_rule

def batch_padding(each_batch, batch_size, padding_idx):
    ''' pad a (possibly short) final batch up to `batch_size` rows.

    Parameters
    ----------
    each_batch : sequence of two int64 tensors, each (n, seq_len) with n <= batch_size
        mutated in place: rows of `padding_idx` are appended to both tensors.
    batch_size : int
        target number of rows.
    padding_idx : int
        fill value for the appended rows.

    Returns
    -------
    (each_batch, num_pad) : the padded pair and the number of rows added.
    '''
    num_pad = batch_size - len(each_batch[0])
    if num_pad:
        each_batch[0] = torch.cat([each_batch[0],
                                   padding_idx * torch.ones((batch_size - len(each_batch[0]),
                                                             len(each_batch[0][0])), dtype=torch.int64)], dim=0)
        each_batch[1] = torch.cat([each_batch[1],
                                   padding_idx * torch.ones((batch_size - len(each_batch[1]),
                                                             len(each_batch[1][0])), dtype=torch.int64)], dim=0)
    return each_batch, num_pad
class DecoderBase(nn.Module):
    ''' Base class for one-step sequence decoders.

    Subclasses keep their recurrent state in `self.hidden_dict`
    (keys 'h' and, for LSTM, 'c').
    '''

    def __init__(self):
        super().__init__()
        self.hidden_dict = {}

    @abc.abstractmethod
    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, hidden_dim)
        '''
        tgt_emb_out = None
        return tgt_emb_out

    @abc.abstractmethod
    def init_hidden(self):
        ''' initialize the hidden states
        '''
        pass

    @abc.abstractmethod
    def feed_hidden(self, hidden_dict_0):
        # seed the first layer of every hidden state from `hidden_dict_0`
        for each_hidden in self.hidden_dict.keys():
            self.hidden_dict[each_hidden][0] = hidden_dict_0[each_hidden]


class GRUDecoder(DecoderBase):
    ''' GRU-based one-step decoder.

    Parameters
    ----------
    input_dim, hidden_dim, num_layers : int
        GRU geometry.
    dropout : float
        inter-layer dropout (disabled when `no_dropout` is True).
    batch_size : int
        fixed batch size; hidden state is allocated for it up front.
    use_gpu : bool
        move the model and hidden states to CUDA.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=False,
                            dropout=self.dropout if not no_dropout else 0
                            )
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset the GRU hidden state to zeros '''
        self.hidden_dict['h'] = torch.zeros((self.num_layers,
                                             self.batch_size,
                                             self.hidden_dim),
                                            requires_grad=False)
        if self.use_gpu:
            self.hidden_dict['h'] = self.hidden_dict['h'].cuda()

    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, 1, hidden_dim)
        '''
        tgt_emb_out, self.hidden_dict['h'] \
            = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
                         self.hidden_dict['h'])
        return tgt_emb_out


class LSTMDecoder(DecoderBase):
    ''' LSTM-based one-step decoder; same interface as GRUDecoder but with
    an additional cell state 'c' in `hidden_dict`.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = nn.LSTM(input_size=self.input_dim,
                             hidden_size=self.hidden_dim,
                             num_layers=self.num_layers,
                             batch_first=True,
                             bidirectional=False,
                             dropout=self.dropout if not no_dropout else 0)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset the LSTM hidden and cell states to zeros '''
        self.hidden_dict['h'] = torch.zeros((self.num_layers,
                                             self.batch_size,
                                             self.hidden_dim),
                                            requires_grad=False)
        self.hidden_dict['c'] = torch.zeros((self.num_layers,
                                             self.batch_size,
                                             self.hidden_dim),
                                            requires_grad=False)
        if self.use_gpu:
            for each_hidden in self.hidden_dict.keys():
                self.hidden_dict[each_hidden] = self.hidden_dict[each_hidden].cuda()

    def forward_one_step(self, tgt_emb_in):
        ''' one-step forward model

        Parameters
        ----------
        tgt_emb_in : Tensor, shape (batch_size, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, 1, hidden_dim)
        '''
        # bug fix: nn.LSTM takes and returns the hidden state as a single
        # (h, c) tuple; the original passed h and c as separate positional
        # arguments and unpacked the 2-tuple result into three names,
        # which raised at runtime.
        tgt_hidden_out, (self.hidden_dict['h'], self.hidden_dict['c']) \
            = self.model(tgt_emb_in.view(self.batch_size, 1, -1),
                         (self.hidden_dict['h'], self.hidden_dict['c']))
        return tgt_hidden_out
class EncoderBase(nn.Module):
    ''' Base class for sequence encoders. '''

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def forward(self, in_seq):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Variable, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        pass

    @abc.abstractmethod
    def init_hidden(self):
        ''' initialize the hidden states
        '''
        pass


class GRUEncoder(EncoderBase):
    ''' GRU sequence encoder with a persistent hidden state.

    Parameters
    ----------
    input_dim, hidden_dim, num_layers : int
        GRU geometry.
    bidirectional : bool
        run the GRU in both directions.
    dropout : float
        inter-layer dropout (disabled when `no_dropout` is True).
    batch_size : int
        fixed batch size; the hidden state is allocated for it up front.
    use_gpu : bool
        move the model and hidden state to CUDA.
    '''

    def __init__(self, input_dim: int, hidden_dim: int, num_layers: int,
                 bidirectional: bool, dropout: float, batch_size: int, use_gpu: bool,
                 no_dropout=False):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.dropout = dropout
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        self.model = nn.GRU(input_size=self.input_dim,
                            hidden_size=self.hidden_dim,
                            num_layers=self.num_layers,
                            batch_first=True,
                            bidirectional=self.bidirectional,
                            dropout=self.dropout if not no_dropout else 0)
        if self.use_gpu:
            self.model.cuda()
        self.init_hidden()

    def init_hidden(self):
        ''' reset the hidden state to zeros '''
        num_directions_layers = (self.bidirectional + 1) * self.num_layers
        self.h0 = torch.zeros((num_directions_layers,
                               self.batch_size,
                               self.hidden_dim),
                              requires_grad=False)
        if self.use_gpu:
            self.h0 = self.h0.cuda()

    def forward(self, in_seq_emb):
        ''' encode a whole sequence in one call

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim)
        '''
        seq_len = in_seq_emb.size(1)
        # the final hidden state is carried over into the next forward call
        out_seq, self.h0 = self.model(in_seq_emb, self.h0)
        # split the flat (num_directions * hidden_dim) axis into two axes
        return out_seq.view(self.batch_size,
                            seq_len,
                            1 + self.bidirectional,
                            self.hidden_dim)
class FullConnectedEncoder(EncoderBase):
    """Fully-connected (MLP) encoder.

    Flattens the whole input sequence into one vector and passes it
    through a stack of ReLU-activated linear layers.
    """

    def __init__(self, input_dim: int, hidden_dim: int, max_len: int, hidden_dim_list: List[int],
                 batch_size: int, use_gpu: bool):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.max_len = max_len
        self.hidden_dim_list = hidden_dim_list
        self.batch_size = batch_size
        self.use_gpu = use_gpu
        # Layer widths: flattened input -> hidden_dim_list ... -> hidden_dim
        in_out_dim_list = [input_dim * max_len] + list(hidden_dim_list) + [hidden_dim]
        self.linear_list = nn.ModuleList(
            [nn.Linear(in_out_dim_list[each_idx], in_out_dim_list[each_idx + 1])
             for each_idx in range(len(in_out_dim_list) - 1)])
        # BUG FIX: use_gpu was accepted but ignored, so this encoder stayed
        # on CPU while the sibling GRU/LSTM encoders moved to GPU.
        if self.use_gpu:
            self.linear_list.cuda()

    def forward(self, in_seq_emb):
        ''' forward model

        Parameters
        ----------
        in_seq_emb : Tensor, shape (batch_size, max_len, input_dim)

        Returns
        -------
        Tensor, shape (batch_size, 1, hidden_dim)
        '''
        batch_size = in_seq_emb.size(0)
        x = in_seq_emb.view(batch_size, -1)
        for each_linear in self.linear_list:
            x = F.relu(each_linear(x))
        return x.view(batch_size, 1, -1)

    def init_hidden(self):
        # Stateless encoder: nothing to initialize.
        pass
class MolecularProdRuleEmbedding(nn.Module):

    ''' molecular fingerprint layer

    Embeds a sequence of production-rule indices by running a simple
    message-passing scheme over each rule's right-hand-side hypergraph
    and summing per-layer readouts into a fixed-size vector.
    '''

    def __init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation,
                 out_dim=32, element_embed_dim=32,
                 num_layers=3, padding_idx=None, use_gpu=False):
        super().__init__()
        if padding_idx is not None:
            assert padding_idx == -1, 'padding_idx must be -1.'
        self.prod_rule_corpus = prod_rule_corpus
        self.layer2layer_activation = layer2layer_activation
        self.layer2out_activation = layer2out_activation
        self.out_dim = out_dim
        self.element_embed_dim = element_embed_dim
        self.num_layers = num_layers
        self.padding_idx = padding_idx
        self.use_gpu = use_gpu

        # BUG FIX: the embeddings used to be raw torch.randn(...,
        # requires_grad=True) tensors and the layers plain Python lists of
        # nn.Linear, so none of them were registered with the module:
        # parameters() / state_dict() missed them and the optimizer never
        # updated them.  On GPU, `.cuda()` on a leaf tensor additionally
        # returned a non-leaf copy whose gradient never reached the leaf.
        # nn.Parameter / nn.ModuleList register everything properly.
        # NOTE(review): this adds new keys to state_dict(); checkpoints
        # written before this fix could not contain them, so loading such
        # checkpoints needs strict=False.
        self.atom_embed = nn.Parameter(
            torch.randn(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim))
        self.bond_embed = nn.Parameter(
            torch.randn(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim))
        self.ext_id_embed = nn.Parameter(
            torch.randn(self.prod_rule_corpus.num_ext_id, self.element_embed_dim))
        self.layer2layer_list = nn.ModuleList(
            [nn.Linear(self.element_embed_dim, self.element_embed_dim)
             for _ in range(num_layers)])
        self.layer2out_list = nn.ModuleList(
            [nn.Linear(self.element_embed_dim, self.out_dim)
             for _ in range(num_layers)])
        if self.use_gpu:
            # Now that everything is registered, one call moves it all.
            self.cuda()

    def forward(self, prod_rule_idx_seq):
        ''' forward model for mini-batch

        Parameters
        ----------
        prod_rule_idx_seq : (batch_size, length)
            Integer indices into the production-rule corpus; the index
            equal to len(prod_rule_list) is treated as padding.

        Returns
        -------
        Tensor, shape (batch_size, length, out_dim)
        '''
        batch_size, length = prod_rule_idx_seq.shape
        out = torch.zeros((batch_size, length, self.out_dim))
        if self.use_gpu:
            out = out.cuda()
        num_rules = len(self.prod_rule_corpus.prod_rule_list)
        for each_batch_idx in range(batch_size):
            for each_idx in range(length):
                rule_idx = int(prod_rule_idx_seq[each_batch_idx, each_idx])
                if rule_idx == num_rules:
                    # padding index: contributes nothing to the output
                    continue
                each_prod_rule = self.prod_rule_corpus.prod_rule_list[rule_idx]
                # Initial embeddings: one per hyperedge ("atom") and node
                # ("bond") of the rule's right-hand side.
                layer_wise_embed_dict = {each_edge: self.atom_embed[
                    each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']]
                                         for each_edge in each_prod_rule.rhs.edges}
                layer_wise_embed_dict.update({each_node: self.bond_embed[
                    each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]
                                              for each_node in each_prod_rule.rhs.nodes})
                for each_node in each_prod_rule.rhs.nodes:
                    if 'ext_id' in each_prod_rule.rhs.node_attr(each_node):
                        layer_wise_embed_dict[each_node] \
                            = layer_wise_embed_dict[each_node] \
                            + self.ext_id_embed[each_prod_rule.rhs.node_attr(each_node)['ext_id']]

                for each_layer in range(self.num_layers):
                    next_layer_embed_dict = {}
                    for each_edge in each_prod_rule.rhs.edges:
                        v = layer_wise_embed_dict[each_edge]
                        for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge):
                            v = v + layer_wise_embed_dict[each_node]
                        next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                        out[each_batch_idx, each_idx, :] \
                            = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
                    for each_node in each_prod_rule.rhs.nodes:
                        v = layer_wise_embed_dict[each_node]
                        for each_edge in each_prod_rule.rhs.adj_edges(each_node):
                            v = v + layer_wise_embed_dict[each_edge]
                        next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v))
                        out[each_batch_idx, each_idx, :] \
                            = out[each_batch_idx, each_idx, :] + self.layer2out_activation(self.layer2out_list[each_layer](v))
                    layer_wise_embed_dict = next_layer_embed_dict

        return out
+ self.prod_rule_corpus = prod_rule_corpus + self.layer2layer_activation = layer2layer_activation + self.layer2out_activation = layer2out_activation + self.out_dim = out_dim + self.element_embed_dim = element_embed_dim + self.num_layers = num_layers + self.padding_idx = padding_idx + self.use_gpu = use_gpu + + self.layer2layer_list = [] + self.layer2out_list = [] + + if self.use_gpu: + self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim).cuda() + self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim).cuda() + for _ in range(num_layers+1): + self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim).cuda()) + self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim).cuda()) + else: + self.atom_embed = nn.Embedding(self.prod_rule_corpus.num_edge_symbol, self.element_embed_dim) + self.bond_embed = nn.Embedding(self.prod_rule_corpus.num_node_symbol, self.element_embed_dim) + for _ in range(num_layers+1): + self.layer2layer_list.append(nn.Linear(self.element_embed_dim, self.element_embed_dim)) + self.layer2out_list.append(nn.Linear(self.element_embed_dim, self.out_dim)) + + + def forward(self, prod_rule_idx_seq): + ''' forward model for mini-batch + + Parameters + ---------- + prod_rule_idx_seq : (batch_size, length) + + Returns + ------- + Variable, shape (batch_size, length, out_dim) + ''' + batch_size, length = prod_rule_idx_seq.shape + if self.use_gpu: + out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda() + else: + out = Variable(torch.zeros((batch_size, length, self.out_dim))) + for each_batch_idx in range(batch_size): + for each_idx in range(length): + if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list): + continue + else: + each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])] + + if self.use_gpu: + layer_wise_embed_dict = 
{each_edge: self.atom_embed( + Variable(torch.LongTensor( + [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']] + ), requires_grad=False).cuda()) + for each_edge in each_prod_rule.rhs.edges} + layer_wise_embed_dict.update({each_node: self.bond_embed( + Variable( + torch.LongTensor([ + each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]), + requires_grad=False).cuda() + ) for each_node in each_prod_rule.rhs.nodes}) + else: + layer_wise_embed_dict = {each_edge: self.atom_embed( + Variable(torch.LongTensor( + [each_prod_rule.rhs.edge_attr(each_edge)['symbol_idx']] + ), requires_grad=False)) + for each_edge in each_prod_rule.rhs.edges} + layer_wise_embed_dict.update({each_node: self.bond_embed( + Variable( + torch.LongTensor([ + each_prod_rule.rhs.node_attr(each_node)['symbol_idx']]), + requires_grad=False) + ) for each_node in each_prod_rule.rhs.nodes}) + + for each_layer in range(self.num_layers): + next_layer_embed_dict = {} + for each_edge in each_prod_rule.rhs.edges: + v = layer_wise_embed_dict[each_edge] + for each_node in each_prod_rule.rhs.nodes_in_edge(each_edge): + v += layer_wise_embed_dict[each_node] + next_layer_embed_dict[each_edge] = self.layer2layer_activation(self.layer2layer_list[each_layer](v)) + for each_node in each_prod_rule.rhs.nodes: + v = layer_wise_embed_dict[each_node] + for each_edge in each_prod_rule.rhs.adj_edges(each_node): + v += layer_wise_embed_dict[each_edge] + next_layer_embed_dict[each_node] = self.layer2layer_activation(self.layer2layer_list[each_layer](v)) + layer_wise_embed_dict = next_layer_embed_dict + for each_edge in each_prod_rule.rhs.edges: + out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v)) + for each_edge in each_prod_rule.rhs.edges: + out[each_batch_idx, each_idx, :] = self.layer2out_activation(self.layer2out_list[self.num_layers](v)) + + return out + + +class MolecularProdRuleEmbeddingUsingFeatures(nn.Module): + + ''' molecular fingerprint layer + ''' + + def 
__init__(self, prod_rule_corpus, layer2layer_activation, layer2out_activation, + out_dim=32, num_layers=3, padding_idx=None, use_gpu=False): + super().__init__() + if padding_idx is not None: + assert padding_idx == -1, 'padding_idx must be -1.' + self.feature_dict, self.feature_dim = prod_rule_corpus.construct_feature_vectors() + self.prod_rule_corpus = prod_rule_corpus + self.layer2layer_activation = layer2layer_activation + self.layer2out_activation = layer2out_activation + self.out_dim = out_dim + self.num_layers = num_layers + self.padding_idx = padding_idx + self.use_gpu = use_gpu + + self.layer2layer_list = [] + self.layer2out_list = [] + + if self.use_gpu: + for each_key in self.feature_dict: + self.feature_dict[each_key] = self.feature_dict[each_key].to_dense().cuda() + for _ in range(num_layers): + self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim).cuda()) + self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim).cuda()) + else: + for _ in range(num_layers): + self.layer2layer_list.append(nn.Linear(self.feature_dim, self.feature_dim)) + self.layer2out_list.append(nn.Linear(self.feature_dim, self.out_dim)) + + + def forward(self, prod_rule_idx_seq): + ''' forward model for mini-batch + + Parameters + ---------- + prod_rule_idx_seq : (batch_size, length) + + Returns + ------- + Variable, shape (batch_size, length, out_dim) + ''' + batch_size, length = prod_rule_idx_seq.shape + if self.use_gpu: + out = Variable(torch.zeros((batch_size, length, self.out_dim))).cuda() + else: + out = Variable(torch.zeros((batch_size, length, self.out_dim))) + for each_batch_idx in range(batch_size): + for each_idx in range(length): + if int(prod_rule_idx_seq[each_batch_idx, each_idx]) == len(self.prod_rule_corpus.prod_rule_list): + continue + else: + each_prod_rule = self.prod_rule_corpus.prod_rule_list[int(prod_rule_idx_seq[each_batch_idx, each_idx])] + edge_list = sorted(list(each_prod_rule.rhs.edges)) + node_list = 
sorted(list(each_prod_rule.rhs.nodes)) + adj_mat = torch.FloatTensor(each_prod_rule.rhs_adj_mat(edge_list + node_list).todense() + np.identity(len(edge_list)+len(node_list))) + if self.use_gpu: + adj_mat = adj_mat.cuda() + layer_wise_embed = [ + self.feature_dict[each_prod_rule.rhs.edge_attr(each_edge)['symbol']] + for each_edge in edge_list]\ + + [self.feature_dict[each_prod_rule.rhs.node_attr(each_node)['symbol']] + for each_node in node_list] + for each_node in each_prod_rule.ext_node.values(): + layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \ + = layer_wise_embed[each_prod_rule.rhs.num_edges + node_list.index(each_node)] \ + + self.feature_dict[('ext_id', each_prod_rule.rhs.node_attr(each_node)['ext_id'])] + layer_wise_embed = torch.stack(layer_wise_embed) + + for each_layer in range(self.num_layers): + message = adj_mat @ layer_wise_embed + next_layer_embed = self.layer2layer_activation(self.layer2layer_list[each_layer](message)) + out[each_batch_idx, each_idx, :] \ + = out[each_batch_idx, each_idx, :] \ + + self.layer2out_activation(self.layer2out_list[each_layer](message)).sum(dim=0) + layer_wise_embed = next_layer_embed + return out diff --git a/images/mhg_example.png b/images/mhg_example.png new file mode 100644 index 0000000000000000000000000000000000000000..3a7dd8ce73476fba75ed242e67147946d99740eb Binary files /dev/null and b/images/mhg_example.png differ diff --git a/images/mhg_example1.png b/images/mhg_example1.png new file mode 100644 index 0000000000000000000000000000000000000000..150b71f10580655433a6f59a60cbc2afc07d8dc8 Binary files /dev/null and b/images/mhg_example1.png differ diff --git a/images/mhg_example2.png b/images/mhg_example2.png new file mode 100644 index 0000000000000000000000000000000000000000..b00f97a7fb3bec25c0e6e42990d18aaa216eff2d Binary files /dev/null and b/images/mhg_example2.png differ diff --git a/load.py b/load.py new file mode 100644 index 
0000000000000000000000000000000000000000..fe3208aa7b5c5efc97691fec0069c20f56f64c15 --- /dev/null +++ b/load.py @@ -0,0 +1,83 @@ +# -*- coding:utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# + +import os +import pickle +import sys + +from rdkit import Chem +import torch +from torch_geometric.utils.smiles import from_smiles + +from typing import Any, Dict, List, Optional, Union +from typing_extensions import Self + +from .graph_grammar.io.smi import hg_to_mol +from .models.mhgvae import GrammarGINVAE + + +class PretrainedModelWrapper: + model: GrammarGINVAE + + def __init__(self, model_dict: Dict[str, Any]) -> None: + json_params = model_dict['gnn_params'] + encoder_params = json_params['encoder_params'] + encoder_params['node_feature_size'] = model_dict['num_features'] + encoder_params['edge_feature_size'] = model_dict['num_edge_features'] + self.model = GrammarGINVAE(model_dict['hrg'], rank=-1, encoder_params=encoder_params, + decoder_params=json_params['decoder_params'], + prod_rule_embed_params=json_params["prod_rule_embed_params"], + batch_size=512, max_len=model_dict['max_length']) + self.model.load_state_dict(model_dict['model_state_dict']) + + self.model.eval() + + def to(self, device: Union[str, int, torch.device]) -> Self: + dev_type = type(device) + if dev_type != torch.device: + if dev_type == str or torch.cuda.is_available(): + device = torch.device(device) + else: + device = torch.device("mps", device) + + self.model = self.model.to(device) + return self + + def encode(self, data: List[str]) -> List[torch.tensor]: + # Need to encode them into a graph nn + output = [] + for d in data: + params = next(self.model.parameters()) + g = from_smiles(d) + if (g.cpu() and params != 'cpu') or (not g.cpu() and params == 'cpu'): + g.to(params.device) + ltvec = self.model.graph_embed(g.x, g.edge_index, g.edge_attr, g.batch) + output.append(ltvec[0]) + return output + + def decode(self, data: 
List[torch.tensor]) -> List[str]: + output = [] + for d in data: + mu, logvar = self.model.get_mean_var(d.unsqueeze(0)) + z = self.model.reparameterize(mu, logvar) + flags, _, hgs = self.model.decode(z) + if flags[0]: + reconstructed_mol, _ = hg_to_mol(hgs[0], True) + output.append(Chem.MolToSmiles(reconstructed_mol)) + else: + output.append(None) + return output + + +def load(model_name: str = "models/mhg_model/pickles/mhggnn_pretrained_model_0724_2023.pickle") -> Optional[ + PretrainedModelWrapper]: + for p in sys.path: + file = p + "/" + model_name + if os.path.isfile(file): + with open(file, "rb") as f: + model_dict = pickle.load(f) + return PretrainedModelWrapper(model_dict) + return None diff --git a/mhg_gnn.egg-info/PKG-INFO b/mhg_gnn.egg-info/PKG-INFO new file mode 100644 index 0000000000000000000000000000000000000000..205fc9ded7ff7f8a67d589d1bc0fa997a77a067c --- /dev/null +++ b/mhg_gnn.egg-info/PKG-INFO @@ -0,0 +1,102 @@ +Metadata-Version: 2.1 +Name: mhg-gnn +Version: 0.0 +Summary: Package for mhg-gnn +Author: team +License: TBD +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.9 +Description-Content-Type: text/markdown +Requires-Dist: networkx>=2.8 +Requires-Dist: numpy<2.0.0,>=1.23.5 +Requires-Dist: pandas>=1.5.3 +Requires-Dist: rdkit-pypi<2023.9.6,>=2022.9.4 +Requires-Dist: torch>=2.0.0 +Requires-Dist: torchinfo>=1.8.0 +Requires-Dist: torch-geometric>=2.3.1 + +# mhg-gnn + +This repository provides PyTorch source code assosiated with our publication, "MHG-GNN: Combination of Molecular Hypergraph Grammar with Graph Neural Network" + +**Paper:** [Arxiv Link](https://arxiv.org/pdf/2309.16374) + +For more information contact: SEIJITKD@jp.ibm.com + +![mhg-gnn](images/mhg_example1.png) + +## Introduction + +We present MHG-GNN, an autoencoder architecture +that has an encoder based on GNN and a decoder based on a sequential model with MHG. 
+Since the encoder is a GNN variant, MHG-GNN can accept any molecule as input, and +demonstrate high predictive performance on molecular graph data. +In addition, the decoder inherits the theoretical guarantee of MHG on always generating a structurally valid molecule as output. + +## Table of Contents + +1. [Getting Started](#getting-started) + 1. [Pretrained Models and Training Logs](#pretrained-models-and-training-logs) + 2. [Replicating Conda Environment](#replicating-conda-environment) +2. [Feature Extraction](#feature-extraction) + +## Getting Started + +**This code and environment have been tested on Intel E5-2667 CPUs at 3.30GHz and NVIDIA A100 Tensor Core GPUs.** + +### Pretrained Models and Training Logs + +We provide checkpoints of the MHG-GNN model pre-trained on a dataset of ~1.34M molecules curated from PubChem. (later) For model weights: [HuggingFace Link]() + +Add the MHG-GNN `pre-trained weights.pt` to the `models/` directory according to your needs. + +### Replicating Conda Environment + +Follow these steps to replicate our Conda environment and install the necessary libraries: + +``` +conda create --name mhg-gnn-env python=3.8.18 +conda activate mhg-gnn-env +``` + +#### Install Packages with Conda + +``` +conda install -c conda-forge networkx=2.8 +conda install numpy=1.23.5 +# conda install -c conda-forge rdkit=2022.9.4 +conda install pytorch=2.0.0 torchvision torchaudio -c pytorch +conda install -c conda-forge torchinfo=1.8.0 +conda install pyg -c pyg +``` + +#### Install Packages with pip +``` +pip install rdkit torch-nl==0.3 torch-scatter torch-sparse +``` + +## Feature Extraction + +The example notebook [mhg-gnn_encoder_decoder_example.ipynb](notebooks/mhg-gnn_encoder_decoder_example.ipynb) contains code to load checkpoint files and use the pre-trained model for encoder and decoder tasks.
+ +To load mhg-gnn, you can simply use: + +```python +import torch +import load + +model = load.load() +``` + +To encode SMILES into embeddings, you can use: + +```python +with torch.no_grad(): + repr = model.encode(["CCO", "O=C=O", "OC(=O)c1ccccc1C(=O)O"]) +``` + +For decoder, you can use the function, so you can return from embeddings to SMILES strings: + +```python +orig = model.decode(repr) +``` diff --git a/mhg_gnn.egg-info/SOURCES.txt b/mhg_gnn.egg-info/SOURCES.txt new file mode 100644 index 0000000000000000000000000000000000000000..f6429c60d226eadd5f9fce9ba11de93451412c34 --- /dev/null +++ b/mhg_gnn.egg-info/SOURCES.txt @@ -0,0 +1,46 @@ +README.md +setup.cfg +setup.py +./graph_grammar/__init__.py +./graph_grammar/hypergraph.py +./graph_grammar/algo/__init__.py +./graph_grammar/algo/tree_decomposition.py +./graph_grammar/graph_grammar/__init__.py +./graph_grammar/graph_grammar/base.py +./graph_grammar/graph_grammar/corpus.py +./graph_grammar/graph_grammar/hrg.py +./graph_grammar/graph_grammar/symbols.py +./graph_grammar/graph_grammar/utils.py +./graph_grammar/io/__init__.py +./graph_grammar/io/smi.py +./graph_grammar/nn/__init__.py +./graph_grammar/nn/dataset.py +./graph_grammar/nn/decoder.py +./graph_grammar/nn/encoder.py +./graph_grammar/nn/graph.py +./models/__init__.py +./models/mhgvae.py +graph_grammar/__init__.py +graph_grammar/hypergraph.py +graph_grammar/algo/__init__.py +graph_grammar/algo/tree_decomposition.py +graph_grammar/graph_grammar/__init__.py +graph_grammar/graph_grammar/base.py +graph_grammar/graph_grammar/corpus.py +graph_grammar/graph_grammar/hrg.py +graph_grammar/graph_grammar/symbols.py +graph_grammar/graph_grammar/utils.py +graph_grammar/io/__init__.py +graph_grammar/io/smi.py +graph_grammar/nn/__init__.py +graph_grammar/nn/dataset.py +graph_grammar/nn/decoder.py +graph_grammar/nn/encoder.py +graph_grammar/nn/graph.py +mhg_gnn.egg-info/PKG-INFO +mhg_gnn.egg-info/SOURCES.txt +mhg_gnn.egg-info/dependency_links.txt 
+mhg_gnn.egg-info/requires.txt +mhg_gnn.egg-info/top_level.txt +models/__init__.py +models/mhgvae.py \ No newline at end of file diff --git a/mhg_gnn.egg-info/dependency_links.txt b/mhg_gnn.egg-info/dependency_links.txt new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/mhg_gnn.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/mhg_gnn.egg-info/requires.txt b/mhg_gnn.egg-info/requires.txt new file mode 100644 index 0000000000000000000000000000000000000000..54aa1b371905a3d46c6cbc15741cfb9b8a376c7d --- /dev/null +++ b/mhg_gnn.egg-info/requires.txt @@ -0,0 +1,7 @@ +networkx>=2.8 +numpy<2.0.0,>=1.23.5 +pandas>=1.5.3 +rdkit-pypi<2023.9.6,>=2022.9.4 +torch>=2.0.0 +torchinfo>=1.8.0 +torch-geometric>=2.3.1 diff --git a/mhg_gnn.egg-info/top_level.txt b/mhg_gnn.egg-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..d606741958961ccbe1f156f690406bcd74658ad4 --- /dev/null +++ b/mhg_gnn.egg-info/top_level.txt @@ -0,0 +1,2 @@ +graph_grammar +models diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1c71cf8133e8cd3edbcb23b4ecf2cd326ec0316 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,5 @@ +# -*- coding:utf-8 -*- +# Rhizome +# Version beta 0.0, August 2023 +# Property of IBM Research, Accelerated Discovery +# diff --git a/models/__pycache__/__init__.cpython-310.pyc b/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecceb80a0ac40826b1514c4006fb28c2089b66bf Binary files /dev/null and b/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/models/__pycache__/mhgvae.cpython-310.pyc b/models/__pycache__/mhgvae.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1de425d10bca4341388de64d4d0081f16fe663ea Binary files /dev/null and b/models/__pycache__/mhgvae.cpython-310.pyc differ diff 
class FeatureEmbedding(nn.Module):
    """Embeds each integer feature column and sums the results.

    Parameters
    ----------
    input_dims : sequence of int
        Vocabulary size of each feature column (one nn.Embedding each).
    embedded_dim : int
        Output embedding dimension shared by all columns.
    """

    def __init__(self, input_dims, embedded_dim):
        super().__init__()
        self.embedding_list = nn.ModuleList()
        for dim in input_dims:
            embedding = nn.Embedding(dim, embedded_dim)
            self.embedding_list.append(embedding)

    def forward(self, x):
        ''' x : Tensor, shape (N, num_columns) of integer codes.

        Returns
        -------
        Tensor, shape (N, embedded_dim): sum of per-column embeddings.
        '''
        # PERF: the module's device cannot change between columns, so the
        # parameter/device lookup is hoisted out of the loop (it used to
        # run once per column).
        device = next(self.parameters()).device
        output = 0
        for i in range(x.shape[1]):
            column = x[:, i].to(torch.int)
            if column.device != device:
                column = column.to(device)
            output = output + self.embedding_list[i](column)
        return output
batch_size: int, rank: int=-1, + no_dropout: bool=False): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.bidirectional = bidirectional + self.dropout = dropout + self.batch_size = batch_size + self.rank = rank + self.model = nn.GRU(input_size=self.input_dim, + hidden_size=self.hidden_dim, + num_layers=self.num_layers, + batch_first=True, + bidirectional=self.bidirectional, + dropout=self.dropout if not no_dropout else 0) + if self.rank >= 0: + if torch.cuda.is_available(): + self.model = self.model.to(rank) + else: + # support mac mps + self.model = self.model.to(torch.device("mps", rank)) + self.init_hidden(self.batch_size) + + def init_hidden(self, bsize): + self.h0 = torch.zeros(((self.bidirectional + 1) * self.num_layers, + min(self.batch_size, bsize), + self.hidden_dim), + requires_grad=False) + if self.rank >= 0: + if torch.cuda.is_available(): + self.h0 = self.h0.to(self.rank) + else: + # support mac mps + self.h0 = self.h0.to(torch.device("mps", self.rank)) + + def to(self, device): + newself = super().to(device) + newself.model = newself.model.to(device) + newself.h0 = newself.h0.to(device) + newself.rank = next(newself.parameters()).get_device() + return newself + + def forward(self, in_seq_emb): + ''' forward model + + Parameters + ---------- + in_seq_emb : Tensor, shape (batch_size, max_len, input_dim) + + Returns + ------- + hidden_seq_emb : Tensor, shape (batch_size, max_len, 1 + bidirectional, hidden_dim) + ''' + # Kishi: I think original MHG had this init_hidden() + self.init_hidden(in_seq_emb.size(0)) + max_len = in_seq_emb.size(1) + hidden_seq_emb, self.h0 = self.model( + in_seq_emb, self.h0) + # As shown as returns, convert hidden_seq_emb: (batch_size, seq_len, (1 or 2) * hidden_size) --> + # (batch_size, seq_len, 1 or 2, hidden_size) + # In the original input the original GRU/LSTM with bidirectional encoding + # has contactinated tensors + # (first half for forward RNN, 
latter half for backward RNN) + # so convert them in a more friendly format packed for each RNN + hidden_seq_emb = hidden_seq_emb.view(-1, + max_len, + 1 + self.bidirectional, + self.hidden_dim) + return hidden_seq_emb + + +class GRUDecoder(DecoderBase): + + def __init__(self, input_dim: int, hidden_dim: int, num_layers: int, + dropout: float, batch_size: int, rank: int=-1, + no_dropout: bool=False): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.num_layers = num_layers + self.dropout = dropout + self.batch_size = batch_size + self.rank = rank + self.model = nn.GRU(input_size=self.input_dim, + hidden_size=self.hidden_dim, + num_layers=self.num_layers, + batch_first=True, + bidirectional=False, + dropout=self.dropout if not no_dropout else 0 + ) + if self.rank >= 0: + if torch.cuda.is_available(): + self.model = self.model.to(self.rank) + else: + # support mac mps + self.model = self.model.to(torch.device("mps", self.rank)) + self.init_hidden(self.batch_size) + + def init_hidden(self, bsize): + self.hidden_dict['h'] = torch.zeros((self.num_layers, + min(self.batch_size, bsize), + self.hidden_dim), + requires_grad=False) + if self.rank >= 0: + if torch.cuda.is_available(): + self.hidden_dict['h'] = self.hidden_dict['h'].to(self.rank) + else: + self.hidden_dict['h'] = self.hidden_dict['h'].to(torch.device("mps", self.rank)) + + def to(self, device): + newself = super().to(device) + newself.model = newself.model.to(device) + for k in self.hidden_dict.keys(): + newself.hidden_dict[k] = newself.hidden_dict[k].to(device) + newself.rank = next(newself.parameters()).get_device() + return newself + + def forward_one_step(self, tgt_emb_in): + ''' one-step forward model + + Parameters + ---------- + tgt_emb_in : Tensor, shape (batch_size, input_dim) + + Returns + ------- + Tensor, shape (batch_size, hidden_dim) + ''' + bsize = tgt_emb_in.size(0) + tgt_emb_out, self.hidden_dict['h'] \ + = self.model(tgt_emb_in.view(bsize, 1, -1), + 
class NodeMLP(nn.Module):
    """Two-layer perceptron: Linear -> BatchNorm1d -> ReLU -> Linear."""

    def __init__(self, input_size, output_size, hidden_size):
        super().__init__()
        self.lin1 = nn.Linear(input_size, hidden_size)
        self.nbat = nn.BatchNorm1d(hidden_size)
        self.lin2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # Same pipeline as before, expressed as a single chained pass.
        hidden = self.nbat(self.lin1(x)).relu()
        return self.lin2(hidden)
    def forward(self, x, edge_index, edge_attr, batch_size):
        """Return a graph-level embedding.

        Concatenates ``global_add_pool`` readouts taken after the input
        embedding and after each GIN layer (1 + proximity_size blocks), i.e.
        a jumping-knowledge-style readout of width
        ``hidden_channels * (1 + proximity_size)``.

        NOTE(review): ``batch_size`` is passed to ``global_add_pool`` as the
        node-to-graph assignment vector, not a count — consider renaming.
        """
        x = x.to(torch.float)
        #print("before: edge_weight.shape=", edge_attr.shape)
        edge_attr = edge_attr.to(torch.float)
        #print("after: edge_weight.shape=", edge_attr.shape)
        x = self.trans(x)
        # TODO Check if this x is consistent with global_add_pool
        hlist = [global_add_pool(x, batch_size)]
        for id, m in enumerate(self.mlist):
            x = m(x, edge_index=edge_index, edge_attr=edge_attr)
            #print("Done with one layer")
            ###if id != self.proximity_size - 1:
            x = x.relu()
            x = F.dropout(x, p=self.dropout, training=self.training)
            #h = global_mean_pool(x, batch_size)
            h = global_add_pool(x, batch_size)
            hlist.append(h)
            #print("Done with one relu call: x.shape=", x.shape)
            #print("calling golbal mean pool")
            #print("calling dropout x.shape=", x.shape)
            #print("x=", x)
        #print("hlist[0].shape=", hlist[0].shape)
        x = torch.cat(hlist, dim=1)
        #print("x.shape=", x.shape)
        x = F.dropout(x, p=self.dropout, training=self.training)

        return x


# TODO copied from MHG implementation and adapted here.
class GrammarSeq2SeqVAE(nn.Module):

    '''
    Variational seq2seq with grammar.

    TODO: rewrite this class using mixin
    '''

    # NOTE(review): the dict-typed defaults below are mutable default
    # arguments shared across instances; safe only if callers never mutate
    # them — confirm, or switch to None-plus-copy.
    def __init__(self, hrg, rank=-1, latent_dim=64, max_len=80,
                 batch_size=64, padding_idx=-1,
                 encoder_params={'hidden_dim': 384, 'num_layers': 3, 'bidirectional': True,
                                 'dropout': 0.1},
                 decoder_params={'hidden_dim': 384, #'num_layers': 2,
                                 'num_layers': 3,
                                 'dropout': 0.1},
                 prod_rule_embed_params={'out_dim': 128},
                 no_dropout=False):

        super().__init__()
        # TODO USE GRU FOR ENCODING AND DECODING
        self.hrg = hrg
        self.rank = rank
        self.prod_rule_corpus = hrg.prod_rule_corpus
        self.prod_rule_embed_params = prod_rule_embed_params

        # +1 reserves the last vocabulary slot for padding/end-of-sequence
        self.vocab_size = hrg.num_prod_rule + 1
        self.batch_size = batch_size
        # np.mod maps the conventional -1 onto vocab_size - 1
        self.padding_idx = np.mod(padding_idx, self.vocab_size)
        self.no_dropout = no_dropout

        self.latent_dim = latent_dim
        self.max_len = max_len
        self.encoder_params = encoder_params
        self.decoder_params = decoder_params

        # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
        embed_out_dim = self.prod_rule_embed_params['out_dim']
        #use MolecularProdRuleEmbedding later on
        self.src_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)
        self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)

        # USE a GRU-based encoder in MHG
        self.encoder = GRUEncoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout,
                                  **self.encoder_params)

        # encoder output width doubles when bidirectional
        lin_dim = (self.encoder_params.get('bidirectional', False) + 1) * self.encoder_params['hidden_dim']
        lin_out_dim = self.latent_dim
        self.hidden2mean = nn.Linear(lin_dim, lin_out_dim, bias=False)
        self.hidden2logvar = nn.Linear(lin_dim, lin_out_dim)

        # USE a GRU-based decoder in MHG
        self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
        self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
        # one latent->hidden projection per decoder hidden-state entry
        self.latent2hidden_dict = nn.ModuleDict()
        dec_lin_out_dim = self.decoder_params['hidden_dim']
        for each_hidden in self.decoder.hidden_dict.keys():
            self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, dec_lin_out_dim)
            if self.rank >= 0:
                if torch.cuda.is_available():
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
                else:
                    # support mac mps
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))

        self.dec2vocab = nn.Linear(dec_lin_out_dim, self.vocab_size)
        self.encoder.init_hidden(self.batch_size)
        self.decoder.init_hidden(self.batch_size)

        # TODO Do we need this?
        if hasattr(self.src_embedding, 'weight'):
            self.src_embedding.weight.data.uniform_(-0.1, 0.1)
        if hasattr(self.tgt_embedding, 'weight'):
            self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)

        # NOTE(review): init_hidden was already called a few lines above;
        # the repetition is harmless (idempotent) but redundant.
        self.encoder.init_hidden(self.batch_size)
        self.decoder.init_hidden(self.batch_size)

    def to(self, device):
        """Move all submodules to ``device`` (explicitly, since some were
        placed on a device manually in __init__)."""
        newself = super().to(device)
        newself.src_embedding = newself.src_embedding.to(device)
        newself.tgt_embedding = newself.tgt_embedding.to(device)
        newself.encoder = newself.encoder.to(device)
        newself.decoder = newself.decoder.to(device)
        newself.dec2vocab = newself.dec2vocab.to(device)
        newself.hidden2mean = newself.hidden2mean.to(device)
        newself.hidden2logvar = newself.hidden2logvar.to(device)
        newself.latent2tgt_emb = newself.latent2tgt_emb.to(device)
        newself.latent2hidden_dict = newself.latent2hidden_dict.to(device)
        return newself

    def forward(self, in_seq, out_seq):
        ''' forward model

        Parameters
        ----------
        in_seq : Variable, shape (batch_size, length)
            each element corresponds to word index.
            where the index should be less than `vocab_size`

        Returns
        -------
        Variable, shape (batch_size, length, vocab_size)
            logit of each word (applying softmax yields the probability)
        '''
        mu, logvar = self.encode(in_seq)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, out_seq), mu, logvar

    def encode(self, in_seq):
        """Encode a rule-index sequence into (mu, logvar) of the latent
        Gaussian. For a bidirectional encoder, the forward state at the last
        position is concatenated with the backward state at the first."""
        src_emb = self.src_embedding(in_seq)
        src_h = self.encoder.forward(src_emb)
        if self.encoder_params.get('bidirectional', False):
            concat_src_h = torch.cat((src_h[:, -1, 0, :], src_h[:, 0, 1, :]), dim=1)
            return self.hidden2mean(concat_src_h), self.hidden2logvar(concat_src_h)
        else:
            return self.hidden2mean(src_h[:, -1, :]), self.hidden2logvar(src_h[:, -1, :])

    def reparameterize(self, mu, logvar, training=True):
        """Sample z = mu + eps * std (reparameterization trick); returns mu
        when not training."""
        if training:
            std = logvar.mul(0.5).exp_()
            device = next(self.parameters()).device
            eps = Variable(std.data.new(std.size()).normal_())
            # NOTE(review): `device` is a torch.device while get_device()
            # returns an int, and `eps.to(device)` is not assigned back —
            # this branch looks like it should read `eps = eps.to(device)`.
            # Left as-is: eps is created on std's device so in practice the
            # move may be a no-op; confirm before changing.
            if device != eps.get_device():
                eps.to(device)
            return eps.mul(std).add_(mu)
        else:
            return mu

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def sample(self, sample_size=-1, deterministic=True, return_z=False):
        """Draw latent vectors from N(0, I) and decode them into hypergraphs.

        NOTE(review): calls `self.init_hidden()` (no such method on this
        class) and hard-codes `.cuda()`; confirm before use on CPU/MPS.
        """
        self.eval()
        self.init_hidden()
        if sample_size == -1:
            sample_size = self.batch_size

        num_iter = int(np.ceil(sample_size / self.batch_size))
        hg_list = []
        z_list = []
        for _ in range(num_iter):
            z = Variable(torch.normal(
                torch.zeros(self.batch_size, self.latent_dim),
                torch.ones(self.batch_size * self.latent_dim))).cuda()
            _, each_hg_list = self.decode(z, deterministic=deterministic)
            z_list.append(z)
            hg_list += each_hg_list
        z = torch.cat(z_list)[:sample_size]
        hg_list = hg_list[:sample_size]
        if return_z:
            return hg_list, z.cpu().detach().numpy()
        else:
            return hg_list

    def decode(self, z=None, out_seq=None, deterministic=True):
        """Decode latent vectors.

        With ``out_seq`` (teacher forcing): returns the per-step vocabulary
        logits, shape (batch, max_len, vocab_size).
        Without ``out_seq`` (free-running): greedily/stochastically applies
        grammar production rules and returns
        (gen_finish_list, vocab_logit, hg_list).
        """
        if z is None:
            z = Variable(torch.normal(
                torch.zeros(self.batch_size, self.latent_dim),
                torch.ones(self.batch_size * self.latent_dim)))
            if self.rank >= 0:
                z = z.to(next(self.parameters()).device)

        # seed the decoder hidden state from the latent vector
        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        bsize = z.size(0)
        self.decoder.init_hidden(bsize)
        self.decoder.feed_hidden(hidden_dict_0)

        if out_seq is not None:
            # teacher forcing: first input is a projection of z, then the
            # ground-truth sequence shifted right by one
            tgt_emb0 = self.latent2tgt_emb(z)
            tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
            out_seq_emb = self.tgt_embedding(out_seq)
            tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
            tgt_emb_pred_list = []
            for each_idx in range(self.max_len):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb[:, each_idx, :].view(bsize, 1, -1))
                tgt_emb_pred_list.append(tgt_emb_pred)
            vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
            return vocab_logit
        else:
            # free-running generation: maintain, per batch element, the
            # partial hypergraph, the stack of open non-terminal edges, and
            # the non-terminal symbol to expand next
            with torch.no_grad():
                tgt_emb = self.latent2tgt_emb(z)
                tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
                tgt_emb_pred_list = []
                stack_list = []
                hg_list = []
                nt_symbol_list = []
                nt_edge_list = []
                gen_finish_list = []
                for _ in range(bsize):
                    stack_list.append([])
                    hg_list.append(None)
                    nt_symbol_list.append(NTSymbol(degree=0,
                                                   is_aromatic=False,
                                                   bond_symbol_list=[]))
                    nt_edge_list.append(None)
                    gen_finish_list.append(False)

                for idx in range(self.max_len):
                    tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                    tgt_emb_pred_list.append(tgt_emb_pred)
                    vocab_logit = self.dec2vocab(tgt_emb_pred)
                    for each_batch_idx in range(bsize):
                        if not gen_finish_list[each_batch_idx]: # if generation has not finished
                            # get production rule greedily
                            prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                                                                         nt_symbol_list[each_batch_idx],
                                                                         deterministic=deterministic)
                            # convert production rule into an index
                            tgt_id = self.hrg.prod_rule_list.index(prod_rule)
                            # apply the production rule
                            hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                            # add non-terminals to the stack
                            stack_list[each_batch_idx].extend(nt_edges[::-1])
                            # if the stack size is 0, generation has finished!
                            if len(stack_list[each_batch_idx]) == 0:
                                gen_finish_list[each_batch_idx] = True
                            else:
                                nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                                nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
                        else:
                            tgt_id = np.mod(self.padding_idx, self.vocab_size)
                        indice_tensor = torch.LongTensor([tgt_id])
                        device = next(self.parameters()).device
                        if indice_tensor.device != device:
                            indice_tensor = indice_tensor.to(device)
                        tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
                vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
                #for id, v in enumerate(gen_finish_list):
                    #if not v:
                    #    print("bacth id={} not finished generating a sequence: ".format(id))
                return gen_finish_list, vocab_logit, hg_list


# TODO A lot of duplicates with GrammarVAE.
Clean up it if necessary
class GrammarGINVAE(nn.Module):

    '''
    Variational autoencoder based on GIN and grammar
    '''

    # NOTE(review): dict-typed defaults are mutable default arguments shared
    # across instances; safe only if never mutated — confirm.
    def __init__(self, hrg, rank=-1, max_len=80,
                 batch_size=64, padding_idx=-1,
                 encoder_params={'node_feature_size': 4, 'edge_feature_size': 3,
                                 'hidden_channels': 64, 'proximity_size': 3,
                                 'dropout': 0.1},
                 decoder_params={'hidden_dim': 384, 'num_layers': 3,
                                 'dropout': 0.1},
                 prod_rule_embed_params={'out_dim': 128},
                 no_dropout=False):

        super().__init__()
        # TODO USE GRU FOR ENCODING AND DECODING
        self.hrg = hrg
        self.rank = rank
        self.prod_rule_corpus = hrg.prod_rule_corpus
        self.prod_rule_embed_params = prod_rule_embed_params

        # +1 reserves the last vocabulary slot for padding/end-of-sequence
        self.vocab_size = hrg.num_prod_rule + 1
        self.batch_size = batch_size
        # np.mod maps the conventional -1 onto vocab_size - 1
        self.padding_idx = np.mod(padding_idx, self.vocab_size)
        self.no_dropout = no_dropout
        self.max_len = max_len
        self.encoder_params = encoder_params
        self.decoder_params = decoder_params

        # TODO Simple embedding is used. Check if a domain-dependent embedding works or not.
        embed_out_dim = self.prod_rule_embed_params['out_dim']
        #use MolecularProdRuleEmbedding later on
        self.tgt_embedding = nn.Embedding(self.vocab_size, embed_out_dim,
                                          padding_idx=self.padding_idx)

        self.encoder = GIN(**self.encoder_params)
        self.latent_dim = self.encoder_params['hidden_channels']
        self.proximity_size = self.encoder_params['proximity_size']
        hidden_dim = self.decoder_params['hidden_dim']
        # GIN's readout is a concat of (1 + proximity_size) pooled blocks,
        # hence the widened input to the mean/logvar projections
        self.hidden2mean = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim, bias=False)
        self.hidden2logvar = nn.Linear(self.latent_dim * (1 + self.proximity_size), self.latent_dim)

        self.decoder = GRUDecoder(input_dim=embed_out_dim, batch_size=self.batch_size,
                                  rank=self.rank, no_dropout=self.no_dropout, **self.decoder_params)
        self.latent2tgt_emb = nn.Linear(self.latent_dim, embed_out_dim)
        # one latent->hidden projection per decoder hidden-state entry
        self.latent2hidden_dict = nn.ModuleDict()
        for each_hidden in self.decoder.hidden_dict.keys():
            self.latent2hidden_dict[each_hidden] = nn.Linear(self.latent_dim, hidden_dim)
            if self.rank >= 0:
                if torch.cuda.is_available():
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(self.rank)
                else:
                    # support mac mps
                    self.latent2hidden_dict[each_hidden] = self.latent2hidden_dict[each_hidden].to(torch.device("mps", self.rank))

        self.dec2vocab = nn.Linear(hidden_dim, self.vocab_size)
        self.decoder.init_hidden(self.batch_size)

        # TODO Do we need this?
        if hasattr(self.tgt_embedding, 'weight'):
            self.tgt_embedding.weight.data.uniform_(-0.1, 0.1)
        # NOTE(review): init_hidden is called twice in this constructor;
        # harmless but redundant.
        self.decoder.init_hidden(self.batch_size)

    def to(self, device):
        """Move encoder/decoder to ``device`` and refresh ``self.rank`` from
        the encoder parameters (get_device() returns -1 on CPU)."""
        newself = super().to(device)
        newself.encoder = newself.encoder.to(device)
        newself.decoder = newself.decoder.to(device)
        newself.rank = next(newself.encoder.parameters()).get_device()
        return newself

    def forward(self, x, edge_index, edge_attr, batch_size, out_seq=None, sched_prob = None):
        """Encode a molecular graph, reparameterize, and decode.

        Returns (decoded, mu, logvar); ``sched_prob`` enables scheduled
        sampling during teacher-forced decoding (see ``decode``).
        """
        mu, logvar = self.encode(x, edge_index, edge_attr, batch_size)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, out_seq, sched_prob=sched_prob), mu, logvar

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def sample(self, sample_size=-1, deterministic=True, return_z=False):
        """Draw latent vectors from N(0, I) and decode them into hypergraphs.

        NOTE(review): calls `self.init_hidden()` (no such method on this
        class) and hard-codes `.cuda()`; confirm before use on CPU/MPS.
        """
        self.eval()
        self.init_hidden()
        if sample_size == -1:
            sample_size = self.batch_size

        num_iter = int(np.ceil(sample_size / self.batch_size))
        hg_list = []
        z_list = []
        for _ in range(num_iter):
            z = Variable(torch.normal(
                torch.zeros(self.batch_size, self.latent_dim),
                torch.ones(self.batch_size * self.latent_dim))).cuda()
            _, each_hg_list = self.decode(z, deterministic=deterministic)
            z_list.append(z)
            hg_list += each_hg_list
        z = torch.cat(z_list)[:sample_size]
        hg_list = hg_list[:sample_size]
        if return_z:
            return hg_list, z.cpu().detach().numpy()
        else:
            return hg_list

    def decode(self, z=None, out_seq=None, deterministic=True, sched_prob=None):
        """Decode latent vectors.

        With ``out_seq``: teacher forcing with optional scheduled sampling —
        with probability (1 - sched_prob) the decoder's own argmax prediction
        is fed back instead of the ground truth. Returns per-step logits.
        Without ``out_seq``: free-running grammar-guided generation returning
        (gen_finish_list, vocab_logit, hg_list).
        """
        if z is None:
            z = Variable(torch.normal(
                torch.zeros(self.batch_size, self.latent_dim),
                torch.ones(self.batch_size * self.latent_dim)))
            if self.rank >= 0:
                z = z.to(next(self.parameters()).device)

        # seed the decoder hidden state from the latent vector
        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        bsize = z.size(0)
        self.decoder.init_hidden(bsize)
        self.decoder.feed_hidden(hidden_dict_0)

        if out_seq is not None:
            tgt_emb0 = self.latent2tgt_emb(z)
            tgt_emb0 = tgt_emb0.view(tgt_emb0.shape[0], 1, tgt_emb0.shape[1])
            out_seq_emb = self.tgt_embedding(out_seq)
            tgt_emb = torch.cat((tgt_emb0, out_seq_emb), dim=1)[:, :-1, :]
            tgt_emb_pred_list = []
            tgt_emb_pred = None
            for each_idx in range(self.max_len):
                # scheduled sampling: use ground truth at step 0, when no
                # schedule is given, or with probability sched_prob
                if tgt_emb_pred is None or sched_prob is None or torch.rand(1)[0] <= sched_prob:
                    inp = tgt_emb[:, each_idx, :].view(bsize, 1, -1)
                else:
                    cur_logit = self.dec2vocab(tgt_emb_pred)
                    yi = torch.argmax(cur_logit, dim=2)
                    inp = self.tgt_embedding(yi)
                tgt_emb_pred = self.decoder.forward_one_step(inp)
                tgt_emb_pred_list.append(tgt_emb_pred)
            vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
            return vocab_logit
        else:
            # free-running generation: per batch element, track the partial
            # hypergraph, the stack of open non-terminal edges, and the
            # non-terminal symbol to expand next
            with torch.no_grad():
                tgt_emb = self.latent2tgt_emb(z)
                tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
                tgt_emb_pred_list = []
                stack_list = []
                hg_list = []
                nt_symbol_list = []
                nt_edge_list = []
                gen_finish_list = []
                for _ in range(bsize):
                    stack_list.append([])
                    hg_list.append(None)
                    nt_symbol_list.append(NTSymbol(degree=0,
                                                   is_aromatic=False,
                                                   bond_symbol_list=[]))
                    nt_edge_list.append(None)
                    gen_finish_list.append(False)

                for _ in range(self.max_len):
                    tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                    tgt_emb_pred_list.append(tgt_emb_pred)
                    vocab_logit = self.dec2vocab(tgt_emb_pred)
                    for each_batch_idx in range(bsize):
                        if not gen_finish_list[each_batch_idx]: # if generation has not finished
                            # get production rule greedily
                            prod_rule = self.hrg.prod_rule_corpus.sample(vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                                                                         nt_symbol_list[each_batch_idx],
                                                                         deterministic=deterministic)
                            # convert production rule into an index
                            tgt_id = self.hrg.prod_rule_list.index(prod_rule)
                            # apply the production rule
                            hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                            # add non-terminals to the stack
                            stack_list[each_batch_idx].extend(nt_edges[::-1])
                            # if the stack size is 0, generation has finished!
                            if len(stack_list[each_batch_idx]) == 0:
                                gen_finish_list[each_batch_idx] = True
                            else:
                                nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                                nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
                        else:
                            tgt_id = np.mod(self.padding_idx, self.vocab_size)
                        indice_tensor = torch.LongTensor([tgt_id])
                        if self.rank >= 0:
                            indice_tensor = indice_tensor.to(next(self.parameters()).device)
                        tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)
                vocab_logit = self.dec2vocab(torch.cat(tgt_emb_pred_list, dim=1))
                return gen_finish_list, vocab_logit, hg_list

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def conditional_distribution(self, z, tgt_id_list):
        """Run the decoder along prespecified rule-id sequences and return,
        per batch element, the masked log-probabilities over the next rule
        (None for elements whose generation already finished).

        NOTE(review): hard-codes `.cuda()` and calls `self.init_hidden()`
        (no such method on this class) — confirm before use.
        """
        self.eval()
        self.init_hidden()
        z = z.cuda()

        hidden_dict_0 = {}
        for each_hidden in self.latent2hidden_dict.keys():
            hidden_dict_0[each_hidden] = self.latent2hidden_dict[each_hidden](z)
        self.decoder.feed_hidden(hidden_dict_0)

        with torch.no_grad():
            tgt_emb = self.latent2tgt_emb(z)
            tgt_emb = tgt_emb.view(tgt_emb.shape[0], 1, tgt_emb.shape[1])
            nt_symbol_list = []
            stack_list = []
            hg_list = []
            nt_edge_list = []
            gen_finish_list = []
            for _ in range(self.batch_size):
                nt_symbol_list.append(NTSymbol(degree=0,
                                               is_aromatic=False,
                                               bond_symbol_list=[]))
                stack_list.append([])
                hg_list.append(None)
                nt_edge_list.append(None)
                gen_finish_list.append(False)

            for each_position in range(len(tgt_id_list[0])):
                tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
                for each_batch_idx in range(self.batch_size):
                    if not gen_finish_list[each_batch_idx]: # if generation has not finished
                        # use the prespecified target ids
                        tgt_id = tgt_id_list[each_batch_idx][each_position]
                        prod_rule = self.hrg.prod_rule_list[tgt_id]
                        # apply the production rule
                        hg_list[each_batch_idx], nt_edges = prod_rule.applied_to(hg_list[each_batch_idx], nt_edge_list[each_batch_idx])
                        # add non-terminals to the stack
                        stack_list[each_batch_idx].extend(nt_edges[::-1])
                        # if the stack size is 0, generation has finished!
                        if len(stack_list[each_batch_idx]) == 0:
                            gen_finish_list[each_batch_idx] = True
                        else:
                            nt_edge_list[each_batch_idx] = stack_list[each_batch_idx].pop()
                            nt_symbol_list[each_batch_idx] = hg_list[each_batch_idx].edge_attr(nt_edge_list[each_batch_idx])['symbol']
                    else:
                        tgt_id = np.mod(self.padding_idx, self.vocab_size)
                    indice_tensor = torch.LongTensor([tgt_id])
                    indice_tensor = indice_tensor.cuda()
                    tgt_emb[each_batch_idx, :] = self.tgt_embedding(indice_tensor)

            # last one step
            conditional_logprob_list = []
            tgt_emb_pred = self.decoder.forward_one_step(tgt_emb)
            vocab_logit = self.dec2vocab(tgt_emb_pred)
            for each_batch_idx in range(self.batch_size):
                if not gen_finish_list[each_batch_idx]: # if generation has not finished
                    # get production rule greedily
                    masked_logprob = self.hrg.prod_rule_corpus.masked_logprob(
                        vocab_logit[each_batch_idx, :, :-1].squeeze().cpu().numpy(),
                        nt_symbol_list[each_batch_idx])
                    conditional_logprob_list.append(masked_logprob)
                else:
                    conditional_logprob_list.append(None)
        return conditional_logprob_list

    #TODO Not tested. Need to implement this in case of molecular structure generation
    def decode_with_beam_search(self, z, beam_width=1):
        ''' Decode a latent vector using beam search.

        Parameters
        ----------
        z
            latent vector
        beam_width : int
            parameter for beam search

        Returns
        -------
        List of Hypergraphs
        '''
        if self.batch_size != 1:
            raise ValueError('this method works only under batch_size=1')
        if self.padding_idx != -1:
            raise ValueError('this method works only under padding_idx=-1')
        top_k_tgt_id_list = [[]] * beam_width
        logprob_list = [0.] * beam_width

        for each_len in range(self.max_len):
            # expand every beam candidate by every vocabulary entry; the last
            # vocab slot (padding) marks "finished" beams
            expanded_logprob_list = np.repeat(logprob_list, self.vocab_size) # including padding_idx
            expanded_length_list = np.array([0] * (beam_width * self.vocab_size))
            for each_beam_idx, each_candidate in enumerate(top_k_tgt_id_list):
                conditional_logprob = self.conditional_distribution(z, [each_candidate])[0]
                if conditional_logprob is None:
                    # candidate already finished: only the padding extension
                    # is allowed, keeping the accumulated score
                    expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
                        = logprob_list[each_beam_idx]
                    expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
                        = -np.inf
                    expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
                        = len(each_candidate)
                else:
                    expanded_logprob_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size - 1]\
                        = logprob_list[each_beam_idx] + conditional_logprob
                    expanded_logprob_list[(each_beam_idx + 1) * self.vocab_size - 1]\
                        = -np.inf
                    expanded_length_list[each_beam_idx * self.vocab_size : (each_beam_idx + 1) * self.vocab_size]\
                        = len(each_candidate) + 1
            # length-normalized scores
            score_list = np.array(expanded_logprob_list) / np.array(expanded_length_list)
            if each_len == 0:
                # all beams are identical (empty) at the first step, so only
                # rank the first beam's expansions
                top_k_list = np.argsort(score_list[:self.vocab_size])[::-1][:beam_width]
            else:
                top_k_list = np.argsort(score_list)[::-1][:beam_width]
            next_top_k_tgt_id_list = []
            next_logprob_list = []
            for each_top_k in top_k_list:
                beam_idx = each_top_k // self.vocab_size
                vocab_idx = each_top_k % self.vocab_size
                if vocab_idx == self.vocab_size - 1:
                    next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx])
                    next_logprob_list.append(expanded_logprob_list[each_top_k])
                else:
                    next_top_k_tgt_id_list.append(top_k_tgt_id_list[beam_idx] + [vocab_idx])
                    next_logprob_list.append(expanded_logprob_list[each_top_k])
            top_k_tgt_id_list = next_top_k_tgt_id_list
            logprob_list = next_logprob_list

        # construct hypergraphs
        hg_list = []
        for each_tgt_id_list in top_k_tgt_id_list:
            hg = None
            stack = []
            nt_edge = None
            for each_idx, each_prod_rule_id in enumerate(each_tgt_id_list):
                prod_rule = self.hrg.prod_rule_list[each_prod_rule_id]
                hg, nt_edges = prod_rule.applied_to(hg, nt_edge)
                stack.extend(nt_edges[::-1])
                try:
                    nt_edge = stack.pop()
                except IndexError:
                    # an empty stack is only legal at the final rule
                    if each_idx == len(each_tgt_id_list) - 1:
                        break
                    else:
                        raise ValueError('some bugs')
            hg_list.append(hg)
        return hg_list

    def graph_embed(self, x, edge_index, edge_attr, batch_size):
        """Raw GIN readout for a batch of graphs (no latent projection)."""
        src_h = self.encoder.forward(x, edge_index, edge_attr, batch_size)
        return src_h

    def encode(self, x, edge_index, edge_attr, batch_size):
        """Encode a batch of graphs into (mu, logvar) of the latent
        Gaussian."""
        #print("device for src_emb=", src_emb.get_device())
        #print("device for self.encoder=", next(self.encoder.parameters()).get_device())
        src_h = self.graph_embed(x, edge_index, edge_attr, batch_size)
        mu, lv = self.get_mean_var(src_h)
        return mu, lv

    def get_mean_var(self, src_h):
        """Project the graph embedding to mean/logvar; both are squashed
        with tanh into (-1, 1)."""
        #src_h = torch.tanh(src_h)
        mu = self.hidden2mean(src_h)
        lv = self.hidden2logvar(src_h)
        mu = torch.tanh(mu)
        lv = torch.tanh(lv)
        return mu, lv

    def reparameterize(self, mu, logvar, training=True):
        """Sample z = mu + eps * std (reparameterization trick); returns mu
        when not training."""
        if training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            if self.rank >= 0:
                eps = eps.to(next(self.parameters()).device)
            return eps.mul(std).add_(mu)
        else:
            return mu

# Copied from the MHG implementation and adapted
class GrammarVAELoss(_Loss):

    '''
    a loss function for Grammar VAE

    Attributes
    ----------
    hrg : HyperedgeReplacementGrammar
    beta : float
        coefficient of KL divergence
    '''

    def __init__(self, rank, hrg, beta=1.0, **kwargs):
        super().__init__(**kwargs)
        self.hrg = hrg
        self.beta = beta
        self.rank = rank

    def forward(self, mu, logvar, in_seq_pred, in_seq):
        ''' compute VAE loss

        Parameters
        ----------
        in_seq_pred : torch.Tensor, shape (batch_size, max_len, vocab_size)
            logit
        in_seq : torch.Tensor, shape (batch_size, max_len)
            each element corresponds to a
word id in vocabulary.
        mu : torch.Tensor, shape (batch_size, hidden_dim)
        logvar : torch.Tensor, shape (batch_size, hidden_dim)
            mean and log variance of the normal distribution
        '''
        batch_size = in_seq_pred.shape[0]
        max_len = in_seq_pred.shape[1]
        vocab_size = in_seq_pred.shape[2]
        # grammar mask: zero out logits of production rules whose left-hand
        # side does not match the non-terminal expanded at this step
        mask = torch.zeros(in_seq_pred.shape)

        for each_batch in range(batch_size):
            # NOTE(review): `flag` is assigned but never used — dead variable.
            flag = True
            for each_idx in range(max_len):
                prod_rule_idx = in_seq[each_batch, each_idx]
                if prod_rule_idx == vocab_size - 1:
                    #### DETERMINE WHETHER THIS SHOULD BE SKIPPED OR NOT
                    # padding position: only the padding logit survives
                    mask[each_batch, each_idx, prod_rule_idx] = 1
                    #break
                    continue
                lhs = self.hrg.prod_rule_corpus.prod_rule_list[prod_rule_idx].lhs_nt_symbol
                lhs_idx = self.hrg.prod_rule_corpus.nt_symbol_list.index(lhs)
                mask[each_batch, each_idx, :-1] = torch.FloatTensor(self.hrg.prod_rule_corpus.lhs_in_prod_rule[lhs_idx])
        if self.rank >= 0:
            # NOTE(review): this module registers no nn.Parameters, so
            # next(self.parameters()) raises StopIteration — the device should
            # probably be taken from in_seq_pred instead; confirm.
            mask = mask.to(next(self.parameters()).device)
        in_seq_pred = mask * in_seq_pred

        cross_entropy = F.cross_entropy(
            in_seq_pred.view(-1, vocab_size),
            in_seq.view(-1),
            reduction='sum',
            #ignore_index=self.ignore_index if self.ignore_index is not None else -100
        )
        kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return cross_entropy + self.beta * kl_div


class VAELoss(_Loss):
    """Standard VAE loss: summed cross-entropy reconstruction plus
    beta-weighted KL divergence to N(0, I)."""

    def __init__(self, beta=0.01):
        super().__init__()
        # beta scales the KL term (beta-VAE style weighting)
        self.beta = beta

    def forward(self, mean, log_var, dec_outputs, targets):

        # NOTE(review): Tensor.get_device() is CUDA-oriented (CPU tensors
        # return -1 or raise depending on torch version); `mean.device` would
        # be the portable spelling — confirm before changing.
        device = mean.get_device()
        if device >= 0:
            targets = targets.to(mean.get_device())
        reconstruction = F.cross_entropy(dec_outputs.view(-1, dec_outputs.size(2)), targets.view(-1), reduction='sum')

        # KL here is the NEGATIVE KL divergence (0.5 * sum(1 + logvar - mu^2
        # - e^logvar)), so the loss adds -beta * KL = +beta * D_KL.
        KL = 0.5 * torch.sum(1 + log_var - mean ** 2 - torch.exp(log_var))
        loss = - self.beta * KL + reconstruction
        return loss
diff --git a/notebooks/mhg-gnn_encoder_decoder_example.ipynb b/notebooks/mhg-gnn_encoder_decoder_example.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..84ff55ec04adaa87452caec876def45757592879
--- /dev/null +++ b/notebooks/mhg-gnn_encoder_decoder_example.ipynb @@ -0,0 +1,114 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "829ddc03", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea820e23", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import load" + ] + }, + { + "cell_type": "markdown", + "id": "b9a51fa8", + "metadata": {}, + "source": [ + "# Load MHG-GNN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c6ea1fc8", + "metadata": {}, + "outputs": [], + "source": [ + "model_ckp = \"models/model_checkpoints/mhg_model/pickles/mhggnn_pretrained_model_radius7_1116_2023.pickle\"\n", + "\n", + "model = load.load(model_name = model_ckp)\n", + "if model is None:\n", + " print(\"Model not loaded, please check you have MHG pickle file\")\n", + "else:\n", + " print(\"MHG model loaded\")" + ] + }, + { + "cell_type": "markdown", + "id": "b4a0b557", + "metadata": {}, + "source": [ + "# Embeddings\n", + "\n", + "※ replace the smiles exaple list with your dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c63a6be6", + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " repr = model.encode([\"CCO\", \"O=C=O\", \"OC(=O)c1ccccc1C(=O)O\"])\n", + " \n", + "# Print the latent vectors\n", + "print(repr)" + ] + }, + { + "cell_type": "markdown", + "id": "a59f9442", + "metadata": {}, + "source": [ + "# Decoding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a0d8a41", + "metadata": {}, + "outputs": [], + "source": [ + "orig = model.decode(repr)\n", + "print(orig)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": 
".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf b/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf new file mode 100644 index 0000000000000000000000000000000000000000..a7dcc1270d1f444f77366013ad2d3d93ebb426ab Binary files /dev/null and b/paper/MHG-GNN_Combination of Molecular Hypergraph Grammar with Graph Neural Network.pdf differ diff --git a/pickles/.DS_Store b/pickles/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5008ddfcf53c02e82d7eee2e57c38e5672ef89f6 Binary files /dev/null and b/pickles/.DS_Store differ diff --git a/pickles/mhggnn_pretrained_model_0724_2023.pickle b/pickles/mhggnn_pretrained_model_0724_2023.pickle new file mode 100644 index 0000000000000000000000000000000000000000..24d5ec31658ddc4d59876cc69996c399727a4e7e --- /dev/null +++ b/pickles/mhggnn_pretrained_model_0724_2023.pickle @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0e5fd6991c7f3a1791c0ba5b2e3805cb6f8f3f734cfdeda145d99cceed4a8533 +size 261888810 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..12d7e158b8963bb101610dcb0a81fd3cc04eaae0 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,37 @@ +[metadata] +name = mhg-gnn +version = attr: .__version__ +description = Package for mhg-gnn +author= team +long_description_content_type=text/markdown +long_description = file: README.md +python_requires = >= 3.9.7 +license = TBD + +classifiers = + Programming Language :: Python :: 3 + Programming Language :: Python :: 3.9 + +[options] +install_requires = + networkx>=2.8 + numpy>=1.23.5, <2.0.0 + pandas>=1.5.3 + rdkit-pypi>=2022.9.4, <2023.9.6 + torch>=2.0.0 + torchinfo>=1.8.0 + torch-geometric>=2.3.1 + 
requests>=2.32.2 + scikit-learn>=1.5.0 + urllib3>=2.2.2 + + +setup_requires = + setuptools +package_dir = + = . +packages=find: +include_package_data = True + +[options.packages.find] +where = . diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..45f160da695819ef6906f6dd332e8398cf419b8e --- /dev/null +++ b/setup.py @@ -0,0 +1,6 @@ +#!/usr/bin/env python + +import setuptools + +if __name__ == "__main__": + setuptools.setup() \ No newline at end of file