File size: 9,192 Bytes
95f97c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import random
import os
import numpy as np
import argparse
import json
from collections import defaultdict
from matplotlib import pyplot as plt
from collections import Counter
from .data_utils import json_read

def set_random_seed(seed):
    """Seed the random number sources this module draws from.

    Seeds Python's ``random`` module and NumPy's global generator with
    ``seed`` and stores ``str(seed)`` in the ``PYTHONHASHSEED`` environment
    variable.

    NOTE(review): CPython reads ``PYTHONHASHSEED`` at interpreter start-up,
    so the assignment presumably only influences processes spawned later --
    confirm this is the intent.
    """
    hash_seed = str(seed)
    random.seed(seed)
    os.environ.update({'PYTHONHASHSEED': hash_seed})
    np.random.seed(seed)

class Reaction_Cluster:
    """Reaction-aware sampler of molecule property records.

    Reactions are read from ``reaction_filename`` and molecule records from
    ``Abstract_property.json`` (both located under ``root``).  Every reaction
    ``r`` receives a sampling probability P(r) stored in
    ``rxn_dict['weight']`` and every valid molecule ``i`` of that reaction a
    conditional probability P(i|r) stored in ``rxn_dict['mol_weight']``.
    Both are built from ``1 / mol_counter[mol]``, i.e. the inverse of the
    number of reactions whose ``valid_mols`` contain the molecule.

    Sampling methods that attach per-reaction information (``role``,
    ``r_smiles``) return *copies* of the molecule records, so the records in
    ``mol_property_map`` are never modified and batches do not affect each
    other.
    """

    def __init__(self, root, reaction_filename, reverse_ratio=0.5):
        """Load the data files and precompute all sampling weights.

        Args:
            root: Directory containing ``reaction_filename`` and
                ``Abstract_property.json``.
            reaction_filename: Name of the reaction JSON file inside ``root``.
            reverse_ratio: Probability with which ``choose_mol`` reverses the
                order of the molecules it returns.
        """
        self.root = root
        self.reaction_data = json_read(os.path.join(self.root, reaction_filename))
        self.property_data = json_read(os.path.join(self.root, 'Abstract_property.json'))
        self.mol_property_map = {d['canon_smiles']: d for d in self.property_data}
        self.reverse_ratio = reverse_ratio
        self.rxn_mols_attr = defaultdict(lambda: {
            'freq': 0,
            'occurrence': 0,
            'in_caption': False,
        })

        self._read_reaction_mols()  # add `valid_mols` in each rxn_dict
        self.mol_counter = Counter(mol for rxn_dict in self.reaction_data for mol in rxn_dict['valid_mols'])
        self._calculate_Pr()  # calculate P(r), add `weight` in each rxn_dict
        self._calculate_Pir()  # calculate P(i|r), add `mol_weight` in each rxn_dict

    def _read_reaction_mols(self):
        """Add ``valid_mols`` and ``mol_role_map`` to every reaction dict.

        ``mol_role_map`` maps each molecule of the reaction that is present in
        ``mol_property_map`` to the first role key it appears under.
        ``valid_mols`` keeps those of them whose record has an ``abstract``
        entry.  Indices of reactions with at least one valid molecule are
        collected in ``self.valid_rxn_indices``.
        """
        self.valid_rxn_indices = []
        for rxn_id, rxn_dict in enumerate(self.reaction_data):
            mol_role_map = {}
            for key in ['REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT']:
                for m in rxn_dict[key]:
                    # The first role wins when a molecule is listed under several keys.
                    if m not in mol_role_map and m in self.mol_property_map:
                        mol_role_map[m] = key
            # Dict order is insertion order, so the molecules stay in R, C, S, P order.
            valid_mols = [
                mol for mol in mol_role_map
                if 'abstract' in self.mol_property_map[mol]
            ]
            if valid_mols:
                self.valid_rxn_indices.append(rxn_id)
            rxn_dict['valid_mols'] = valid_mols
            rxn_dict['mol_role_map'] = mol_role_map

    def _calculate_Pr(self):
        """Store the normalised reaction probability P(r) as ``rxn_dict['weight']``.

        The unnormalised weight of a reaction is the sum of
        ``1 / mol_counter[mol]`` over its valid molecules; reactions without
        valid molecules therefore get weight 0.
        """
        total_weights = 0
        for rxn_dict in self.reaction_data:
            rxn_weight = sum(1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols'])
            rxn_dict['weight'] = rxn_weight
            total_weights += rxn_weight
        for rxn_dict in self.reaction_data:
            rxn_dict['weight'] = rxn_dict['weight'] / total_weights

    def _calculate_Pir(self):
        """Store the per-reaction molecule probabilities P(i|r) as ``rxn_dict['mol_weight']``."""
        for rxn_dict in self.reaction_data:
            mol_weight = {mol: 1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols']}
            total_weight = sum(mol_weight.values())
            # An empty `mol_weight` yields an empty dict, so no division by zero occurs.
            rxn_dict['mol_weight'] = {m: w / total_weight for m, w in mol_weight.items()}

    def choose_mol(self, valid_mols, k=4, weights=None):
        """Pick up to ``k`` molecules from ``valid_mols`` without replacement.

        Args:
            valid_mols: Sequence of candidate molecules.
            k: Maximum number of molecules to return.  When ``k`` is at least
                ``len(valid_mols)`` all molecules are returned.
            weights: Optional probabilities aligned with ``valid_mols``,
                forwarded to ``np.random.choice`` as ``p``.

        Returns:
            The chosen molecules in their original relative order, or in
            reversed order with probability ``self.reverse_ratio``.
        """
        if k >= len(valid_mols):
            sampled_indices = list(range(len(valid_mols)))
        else:
            sampled_indices = np.random.choice(len(valid_mols), k, replace=False, p=weights)
            sampled_indices = list(sampled_indices)
        sampled_indices = sorted(sampled_indices)
        if random.random() < self.reverse_ratio:  # reverse the indices with reverse_ratio chance.
            sampled_indices.reverse()
        return [valid_mols[i] for i in sampled_indices]

    def _build_property_batch(self, rxn, sampled_mols):
        """Return copies of the records of ``sampled_mols`` tagged with their role in ``rxn``."""
        mol_property_batch = []
        for mol in sampled_mols:
            # Copy before annotating: the same record is shared by every
            # reaction containing the molecule, and its role differs per reaction.
            mol_property = dict(self.mol_property_map[mol])
            mol_property['role'] = rxn['mol_role_map'][mol]
            mol_property_batch.append(mol_property)
        return mol_property_batch

    def sample_mol_batch(self, index=None, k=4):
        """Sample up to ``k`` molecule records from one reaction.

        Args:
            index: Index into ``self.reaction_data``.  When ``None`` a
                reaction is drawn with ``sample_rxn_index``.
            k: Maximum number of molecules in the batch.

        Returns:
            A list of molecule record copies, each with a ``role`` entry and,
            when the reaction has an ``rsmiles_map`` covering the molecule, an
            ``r_smiles`` entry.  The list is empty when the reaction has no
            valid molecules.
        """
        if index is None:
            index = self.sample_rxn_index(1)[0]
        assert index < len(self.reaction_data)
        rxn = self.reaction_data[index]
        if not rxn['mol_weight']:
            # Reachable when reaction weights are made uniform (see `_randomly`).
            return []
        valid_mols, weights = zip(*rxn['mol_weight'].items())

        sampled_mols = self.choose_mol(valid_mols, k=k, weights=weights)
        mol_property_batch = self._build_property_batch(rxn, sampled_mols)
        if 'rsmiles_map' in rxn:
            rsmiles_map = random.choice(rxn['rsmiles_map'])
            for mol_property in mol_property_batch:
                canon_smiles = mol_property['canon_smiles']
                if canon_smiles in rsmiles_map:
                    mol_property['r_smiles'] = rsmiles_map[canon_smiles]
        return mol_property_batch

    def sample_rxn_index(self, num_samples):
        """Draw ``num_samples`` distinct reaction indices according to ``weight``."""
        indices = range(len(self.reaction_data))
        weights = [d['weight'] for d in self.reaction_data]
        return np.random.choice(indices, num_samples, replace=False, p=weights)

    def __call__(self, rxn_num=1000, k=4):
        """Return ``rxn_num`` molecule batches sampled from weighted reactions."""
        sampled_indices = self.sample_rxn_index(rxn_num)
        return [self.sample_mol_batch(idx, k=k) for idx in sampled_indices]

    def generate_batch_uniform_rxn(self, rxn_num=1000, k=4):
        """Return batches from ``rxn_num`` reactions drawn uniformly.

        Reactions are drawn without replacement from ``valid_rxn_indices`` and
        molecules inside each reaction are chosen without weights.  Records
        are copies carrying a ``role`` entry.
        """
        assert rxn_num <= len(self.valid_rxn_indices)
        sampled_rxn_indices = random.sample(self.valid_rxn_indices, rxn_num)
        sampled_batch = []
        for rxn_id in sampled_rxn_indices:
            rxn = self.reaction_data[rxn_id]
            sampled_mols = self.choose_mol(rxn['valid_mols'], k=k, weights=None)
            sampled_batch.append(self._build_property_batch(rxn, sampled_mols))
        return sampled_batch

    def generate_batch_uniform_mol(self, rxn_num=1000, k=4):
        """Return ``rxn_num`` batches of ``k`` records ignoring reaction membership.

        Molecules are drawn without replacement from
        ``mol_counter.elements()``, in which each molecule is repeated as
        often as it was counted.  The returned records are the shared entries
        of ``mol_property_map`` (no ``role`` is attached).
        """
        valid_mols = list(self.mol_counter.elements())
        assert rxn_num * k <= len(valid_mols)
        sampled_mol_ids = random.sample(range(len(valid_mols)), rxn_num * k)
        return [
            [self.mol_property_map[valid_mols[mol_id]] for mol_id in sampled_mol_ids[i * k:(i + 1) * k]]
            for i in range(rxn_num)
        ]

    def generate_batch_single(self, rxn_num=1000):
        """Return ``rxn_num`` single-record batches drawn from ``mol_counter.elements()``."""
        valid_mols = list(self.mol_counter.elements())
        sampled_mols = random.sample(valid_mols, rxn_num)
        return [[self.mol_property_map[mol]] for mol in sampled_mols]

    # visualize probability for molecules in caption dataset.
    def visualize_mol_distribution(self):
        """Compute the sampling probabilities of molecules and reactions.

        Returns:
            ``(prob_values, rxn_weights)``: the probability of every molecule
            in ``mol_property_map`` multiplied by the number of molecules, and
            the probability of every reaction multiplied by the number of
            reactions.
        """
        prob_dict = {mol: 0.0 for mol in self.mol_property_map}
        N = len(prob_dict)
        M = len(self.reaction_data)
        print(f'Number of molecules in Caption Dataset: {N}')
        print(f'Number of Reactions in Reaction Dataset: {M}')

        # prob distribution for molecules: P(i) = sum_r P(i|r) * P(r)
        for rxn_dict in self.reaction_data:
            for mol, weight in rxn_dict['mol_weight'].items():
                prob_dict[mol] += weight * rxn_dict['weight']
        # sum of prob_dict.values() should already be 1.
        prob_values = np.array(list(prob_dict.values()))
        prob_values *= N

        # prob distribution for reactions
        rxn_weights = np.array([d['weight'] for d in self.reaction_data])
        # sum of rxn_weights should already be 1.
        rxn_weights *= M

        return prob_values, rxn_weights

    # visualize the frequency for molecules in caption dataset.
    def visualize_mol_frequency(self, rxn_num=1000, k=4, epochs=100):
        """Empirically count how often molecules and reactions get sampled.

        Runs ``epochs`` rounds of drawing ``rxn_num`` reactions and up to
        ``k`` molecules per reaction.

        Returns:
            ``(sampled_mols_count, sampled_rxns_count)``: arrays of counts,
            ordered by sorted molecule / reaction index.  Items that were
            never sampled do not appear.
        """
        sampled_mols_counter = Counter()
        sampled_rxns_counter = Counter()
        for _ in range(epochs):
            rxn_indices = self.sample_rxn_index(rxn_num)
            sampled_rxns_counter.update(rxn_indices)
            for index in rxn_indices:
                rxn = self.reaction_data[index]
                if not rxn['valid_mols']:
                    continue
                valid_mols, weights = zip(*rxn['mol_weight'].items())
                mol_batch = self.choose_mol(valid_mols, k=k, weights=weights)
                sampled_mols_counter.update(mol_batch)
        sampled_mols_count = np.array([c for _, c in sorted(sampled_mols_counter.items())])
        sampled_rxns_count = np.array([c for _, c in sorted(sampled_rxns_counter.items())])
        return sampled_mols_count, sampled_rxns_count

    def _randomly(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` with all sampling weights made uniform.

        The real weights are backed up inside each reaction dict and restored
        afterwards, even when ``func`` raises.

        Returns:
            Whatever ``func`` returns.
        """
        # make fake weights and backup the weights
        for rxn_dict in self.reaction_data:
            rxn_dict['weight_bak'] = rxn_dict['weight']
            rxn_dict['weight'] = 1 / len(self.reaction_data)
            rxn_dict['mol_weight_bak'] = rxn_dict['mol_weight']
            rxn_dict['mol_weight'] = {m: 1 / len(rxn_dict['mol_weight']) for m in rxn_dict['mol_weight']}

        try:
            # run the function
            result = func(*args, **kwargs)
        finally:
            # weights recovery
            for rxn_dict in self.reaction_data:
                rxn_dict['weight'] = rxn_dict.pop('weight_bak')
                rxn_dict['mol_weight'] = rxn_dict.pop('mol_weight_bak')

        return result