Spaces:
Runtime error
Runtime error
File size: 9,192 Bytes
95f97c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import random
import os
import numpy as np
import argparse
import json
from collections import defaultdict
from matplotlib import pyplot as plt
from collections import Counter
from .data_utils import json_read
def set_random_seed(seed):
    """Seed the stdlib and NumPy random generators with ``seed``.

    The seed is also exported, as a string, through the ``PYTHONHASHSEED``
    environment variable.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
class Reaction_Cluster:
    """Sample batches of molecules grouped by the reactions they occur in.

    Every reaction record gets a sampling probability ``weight`` (P(r)) that
    is proportional to the sum of ``1 / count(mol)`` over its valid molecules,
    and every valid molecule inside a reaction gets a conditional probability
    ``mol_weight[mol]`` (P(i|r)) proportional to ``1 / count(mol)``, where
    ``count`` is the number of reactions that list the molecule as valid.

    A molecule is *valid* when it has an entry in ``Abstract_property.json``
    and that entry contains an ``'abstract'`` key.
    """

    def __init__(self, root, reaction_filename, reverse_ratio=0.5):
        """Load the reaction and property files under ``root`` and precompute weights.

        Args:
            root: Directory containing ``reaction_filename`` and
                ``Abstract_property.json``.
            reaction_filename: Name of the JSON file holding the reaction records.
            reverse_ratio: Probability with which ``choose_mol`` reverses the
                order of the molecules it returns.
        """
        self.root = root
        self.reaction_data = json_read(os.path.join(self.root, reaction_filename))
        self.property_data = json_read(os.path.join(self.root, 'Abstract_property.json'))
        self.mol_property_map = {d['canon_smiles']: d for d in self.property_data}
        self.reverse_ratio = reverse_ratio
        self.rxn_mols_attr = defaultdict(lambda: {
            'freq': 0,
            'occurrence': 0,
            'in_caption': False,
        })
        self._read_reaction_mols()  # adds `valid_mols` and `mol_role_map` to each rxn_dict
        self.mol_counter = Counter(
            mol for rxn_dict in self.reaction_data for mol in rxn_dict['valid_mols']
        )
        self._calculate_Pr()  # P(r): adds `weight` to each rxn_dict
        self._calculate_Pir()  # P(i|r): adds `mol_weight` to each rxn_dict

    def _read_reaction_mols(self):
        """Attach ``valid_mols`` and ``mol_role_map`` to every reaction record.

        Also fills ``self.valid_rxn_indices`` with the indices of reactions
        that have at least one valid molecule.
        """
        self.valid_rxn_indices = []
        for rxn_id, rxn_dict in enumerate(self.reaction_data):
            mol_role_map = {}
            for key in ['REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT']:
                for m in rxn_dict[key]:
                    # The first role a molecule appears under wins; molecules
                    # without a property entry are ignored entirely.
                    if m not in mol_role_map and m in self.mol_property_map:
                        mol_role_map[m] = key
            # dict insertion order keeps the molecules in R, C, S, P order.
            valid_mols = [
                mol for mol in mol_role_map
                if 'abstract' in self.mol_property_map[mol]
            ]
            if valid_mols:
                self.valid_rxn_indices.append(rxn_id)
            rxn_dict['valid_mols'] = valid_mols
            rxn_dict['mol_role_map'] = mol_role_map

    def _calculate_Pr(self):
        """Store the normalised reaction probability P(r) as ``rxn_dict['weight']``.

        Raises:
            ZeroDivisionError: if there are reactions but none of them has a
                valid molecule (the total weight is then zero).
        """
        total_weights = 0
        for rxn_dict in self.reaction_data:
            rxn_weight = sum(1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols'])
            rxn_dict['weight'] = rxn_weight
            total_weights += rxn_weight
        for rxn_dict in self.reaction_data:
            rxn_dict['weight'] = rxn_dict['weight'] / total_weights

    def _calculate_Pir(self):
        """Store the per-reaction molecule probabilities P(i|r) as ``rxn_dict['mol_weight']``."""
        for rxn_dict in self.reaction_data:
            mol_weight = {mol: 1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols']}
            total_weight = sum(mol_weight.values())
            # For a reaction without valid molecules the comprehension is
            # empty, so the zero total is never used as a divisor.
            rxn_dict['mol_weight'] = {m: w / total_weight for m, w in mol_weight.items()}

    def choose_mol(self, valid_mols, k=4, weights=None):
        """Pick up to ``k`` molecules from ``valid_mols``, keeping or reversing their order.

        Args:
            valid_mols: Sequence of candidate molecules.
            k: Maximum number of molecules to return; when ``k`` is at least
                ``len(valid_mols)`` every molecule is returned.
            weights: Optional per-molecule probabilities forwarded to
                ``np.random.choice`` (sampling is without replacement).

        Returns:
            A list of molecules in their original relative order, or in the
            reversed order with probability ``self.reverse_ratio``.
        """
        if k >= len(valid_mols):
            sampled_indices = list(range(len(valid_mols)))
        else:
            sampled_indices = list(
                np.random.choice(len(valid_mols), k, replace=False, p=weights)
            )
        sampled_indices = sorted(sampled_indices)
        if random.random() < self.reverse_ratio:
            sampled_indices.reverse()
        return [valid_mols[i] for i in sampled_indices]

    def _attach_roles(self, rxn, sampled_mols):
        """Return per-molecule property dicts annotated with their role in ``rxn``.

        Shallow copies are returned so that the shared entries of
        ``self.mol_property_map`` are never mutated: the same molecule can
        play different roles in different reactions, and annotating the shared
        dict would overwrite the role in batches that were already handed out.
        """
        batch = []
        for mol in sampled_mols:
            mol_property = dict(self.mol_property_map[mol])
            mol_property['role'] = rxn['mol_role_map'][mol]
            batch.append(mol_property)
        return batch

    def sample_mol_batch(self, index=None, k=4):
        """Sample up to ``k`` molecules from one reaction.

        Args:
            index: Index into ``self.reaction_data``; when ``None`` a reaction
                is drawn with ``sample_rxn_index``.
            k: Maximum number of molecules in the batch.

        Returns:
            A list of property dicts (copies), each carrying a ``'role'`` key
            and, when the reaction has an ``'rsmiles_map'`` covering the
            molecule, an ``'r_smiles'`` key. The list is empty when the
            reaction has no valid molecules.
        """
        if index is None:
            index = self.sample_rxn_index(1)[0]
        assert index < len(self.reaction_data)
        rxn = self.reaction_data[index]
        if not rxn['mol_weight']:
            # Nothing to sample; unpacking zip(*()) below would raise.
            return []
        valid_mols, weights = zip(*rxn['mol_weight'].items())
        sampled_mols = self.choose_mol(valid_mols, k=k, weights=weights)
        mol_property_batch = self._attach_roles(rxn, sampled_mols)
        if 'rsmiles_map' in rxn:
            rsmiles_map = random.choice(rxn['rsmiles_map'])
            for mol_property in mol_property_batch:
                canon_smiles = mol_property['canon_smiles']
                if canon_smiles in rsmiles_map:
                    mol_property['r_smiles'] = rsmiles_map[canon_smiles]
        return mol_property_batch

    def sample_rxn_index(self, num_samples):
        """Draw ``num_samples`` distinct reaction indices according to P(r)."""
        indices = range(len(self.reaction_data))
        weights = [d['weight'] for d in self.reaction_data]
        return np.random.choice(indices, num_samples, replace=False, p=weights)

    def __call__(self, rxn_num=1000, k=4):
        """Sample ``rxn_num`` reactions by P(r) and a molecule batch from each."""
        sampled_indices = self.sample_rxn_index(rxn_num)
        return [self.sample_mol_batch(idx, k=k) for idx in sampled_indices]

    def generate_batch_uniform_rxn(self, rxn_num=1000, k=4):
        """Sample ``rxn_num`` valid reactions uniformly and up to ``k`` molecules from each.

        Molecules inside a reaction are chosen uniformly as well. Each
        returned property dict is a copy annotated with a ``'role'`` key.
        """
        assert rxn_num <= len(self.valid_rxn_indices)
        sampled_rxn_indices = random.sample(self.valid_rxn_indices, rxn_num)
        sampled_batch = []
        for rxn_id in sampled_rxn_indices:
            rxn = self.reaction_data[rxn_id]
            sampled_mols = self.choose_mol(rxn['valid_mols'], k=k, weights=None)
            sampled_batch.append(self._attach_roles(rxn, sampled_mols))
        return sampled_batch

    def generate_batch_uniform_mol(self, rxn_num=1000, k=4):
        """Build ``rxn_num`` batches of ``k`` molecules drawn without regard to reactions.

        Molecules are drawn without replacement from the multiset of all
        valid-molecule occurrences, so frequent molecules are more likely.
        """
        valid_mols = list(self.mol_counter.elements())
        assert rxn_num * k <= len(valid_mols)
        sampled_mol_ids = random.sample(range(len(valid_mols)), rxn_num * k)
        sampled_batch = []
        for i in range(rxn_num):
            chunk = sampled_mol_ids[i * k:(i + 1) * k]
            sampled_batch.append([self.mol_property_map[valid_mols[mol_id]] for mol_id in chunk])
        return sampled_batch

    def generate_batch_single(self, rxn_num=1000):
        """Return ``rxn_num`` single-molecule batches drawn from all valid occurrences."""
        valid_mols = list(self.mol_counter.elements())
        sampled_mols = random.sample(valid_mols, rxn_num)
        return [[self.mol_property_map[mol]] for mol in sampled_mols]

    # Visualize the sampling probability of molecules in the caption dataset.
    def visualize_mol_distribution(self):
        """Return molecule and reaction sampling probabilities, rescaled by their counts.

        Returns:
            ``(prob_values, rxn_weights)``: the marginal molecule probability
            ``sum_r P(i|r) * P(r)`` multiplied by the number of molecules, and
            P(r) multiplied by the number of reactions.
        """
        prob_dict = {mol: 0.0 for mol in self.mol_property_map}
        N = len(prob_dict)
        M = len(self.reaction_data)
        print(f'Number of molecules in Caption Dataset: {N}')
        print(f'Number of Reactions in Reaction Dataset: {M}')
        # Probability distribution over molecules.
        for rxn_dict in self.reaction_data:
            for mol, weight in rxn_dict['mol_weight'].items():
                prob_dict[mol] += weight * rxn_dict['weight']
        # The values of prob_dict should already sum to 1.
        prob_values = np.array(list(prob_dict.values()))
        prob_values *= N
        # Probability distribution over reactions; should already sum to 1.
        rxn_weights = np.array([d['weight'] for d in self.reaction_data])
        rxn_weights *= M
        return prob_values, rxn_weights

    # Visualize the empirical sampling frequency of molecules in the caption dataset.
    def visualize_mol_frequency(self, rxn_num=1000, k=4, epochs=100):
        """Simulate ``epochs`` rounds of sampling and count how often items are drawn.

        Returns:
            ``(sampled_mols_count, sampled_rxns_count)``: arrays of counts,
            ordered by sorted molecule key and sorted reaction index.
        """
        sampled_mols_counter = Counter()
        sampled_rxns_counter = Counter()
        for _ in range(epochs):
            rxn_indices = self.sample_rxn_index(rxn_num)
            sampled_rxns_counter.update(rxn_indices)
            for index in rxn_indices:
                rxn = self.reaction_data[index]
                if not rxn['valid_mols']:
                    continue
                valid_mols, weights = zip(*rxn['mol_weight'].items())
                mol_batch = self.choose_mol(valid_mols, k=k, weights=weights)
                sampled_mols_counter.update(mol_batch)
        sampled_mols_count = np.array([c for _, c in sorted(sampled_mols_counter.items())])
        sampled_rxns_count = np.array([c for _, c in sorted(sampled_rxns_counter.items())])
        return sampled_mols_count, sampled_rxns_count

    def _randomly(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` with all weights temporarily made uniform.

        The real ``weight`` / ``mol_weight`` values are restored afterwards,
        even when ``func`` raises.
        """
        # Back up the real weights and install uniform ones.
        for rxn_dict in self.reaction_data:
            rxn_dict['weight_bak'] = rxn_dict['weight']
            rxn_dict['weight'] = 1 / len(self.reaction_data)
            rxn_dict['mol_weight_bak'] = rxn_dict['mol_weight']
            rxn_dict['mol_weight'] = {
                m: 1 / len(rxn_dict['mol_weight']) for m in rxn_dict['mol_weight']
            }
        try:
            return func(*args, **kwargs)
        finally:
            # Restore the real weights and drop the backups.
            for rxn_dict in self.reaction_data:
                rxn_dict['weight'] = rxn_dict.pop('weight_bak')
                rxn_dict['mol_weight'] = rxn_dict.pop('mol_weight_bak')
|