Spaces:
Runtime error
Runtime error
import random | |
import os | |
import numpy as np | |
import argparse | |
import json | |
from collections import defaultdict | |
from matplotlib import pyplot as plt | |
from collections import Counter | |
from .data_utils import json_read | |
def set_random_seed(seed):
    """Seed Python's ``random`` module and NumPy's global RNG with ``seed``.

    ``PYTHONHASHSEED`` is exported as well. Changing that variable at runtime
    does not alter hash randomization of the interpreter that is already
    running. It only affects processes started afterwards that inherit this
    environment.
    """
    seed_text = str(seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = seed_text
    np.random.seed(seed)
class Reaction_Cluster:
    """Sample molecules from reactions with inverse-frequency weighting.

    Reactions are read from ``root/reaction_filename`` and molecule properties
    from ``root/Abstract_property.json``, both through ``json_read``.
    ``json_read`` presumably returns a list of dicts (TODO confirm). The code
    relies on each property dict having a ``'canon_smiles'`` key. It relies on
    each reaction dict having ``'REACTANT'``, ``'CATALYST'``, ``'SOLVENT'`` and
    ``'PRODUCT'`` lists, plus an optional ``'rsmiles_map'`` list.

    A molecule of a reaction is *valid* when it is in the property table and
    its property dict has an ``'abstract'`` key. Sampling weights are computed
    once in ``__init__``:

    * P(r) is the reaction weight, proportional to ``sum(1 / count(m))`` over
      the reaction's valid molecules. ``count(m)`` is the number of reactions
      in which ``m`` is valid. Stored in ``rxn_dict['weight']``.
    * P(i|r) is the molecule weight inside a reaction, proportional to
      ``1 / count(i)``. Stored in ``rxn_dict['mol_weight']``.

    Property dicts returned by the sampling methods are shallow copies of the
    entries in ``mol_property_map``. Per-sample fields (``'role'``,
    ``'r_smiles'``) therefore never leak into the shared table or into other
    samples.
    """

    def __init__(self, root, reaction_filename, reverse_ratio=0.5):
        """
        Args:
            root: directory containing ``reaction_filename`` and
                ``Abstract_property.json``.
            reaction_filename: name of the reaction file inside ``root``.
            reverse_ratio: probability that ``choose_mol`` returns the sampled
                molecules in reverse order.
        """
        self.root = root
        self.reaction_data = json_read(os.path.join(self.root, reaction_filename))
        self.property_data = json_read(os.path.join(self.root, 'Abstract_property.json'))
        # canonical SMILES -> property dict
        self.mol_property_map = {d['canon_smiles']: d for d in self.property_data}
        self.reverse_ratio = reverse_ratio
        # NOTE(review): not read by any method visible here; kept for external users.
        self.rxn_mols_attr = defaultdict(lambda: {
            'freq': 0,
            'occurrence': 0,
            'in_caption': False,
        })
        self._read_reaction_mols()  # adds `valid_mols` and `mol_role_map` to each rxn_dict
        # Number of reactions in which each molecule is valid.
        self.mol_counter = Counter(mol for rxn_dict in self.reaction_data for mol in rxn_dict['valid_mols'])
        self._calculate_Pr()  # P(r): adds `weight` to each rxn_dict
        self._calculate_Pir()  # P(i|r): adds `mol_weight` to each rxn_dict

    def _read_reaction_mols(self):
        """Annotate each reaction with its valid molecules and their roles.

        Adds two keys to every reaction dict:

        * ``mol_role_map`` maps each molecule found in the property table to
          the first role it appears under, checking REACTANT, CATALYST,
          SOLVENT, PRODUCT in that order.
        * ``valid_mols`` lists the molecules of ``mol_role_map`` whose
          property dict has an ``'abstract'``, in that same role order.

        Indices of reactions with at least one valid molecule are stored in
        ``self.valid_rxn_indices``.
        """
        self.valid_rxn_indices = []
        for rxn_id, rxn_dict in enumerate(self.reaction_data):
            mol_role_map = {}
            for key in ['REACTANT', 'CATALYST', 'SOLVENT', 'PRODUCT']:
                for m in rxn_dict[key]:
                    # A molecule listed under several roles keeps the first one.
                    if m in mol_role_map:
                        continue
                    if m in self.mol_property_map:
                        mol_role_map[m] = key
            # Every key of mol_role_map is in mol_property_map by construction.
            # Dict insertion order keeps the R, C, S, P order.
            valid_mols = [mol for mol in mol_role_map if 'abstract' in self.mol_property_map[mol]]
            if len(valid_mols) > 0:
                self.valid_rxn_indices.append(rxn_id)
            rxn_dict['valid_mols'] = valid_mols
            rxn_dict['mol_role_map'] = mol_role_map

    def _calculate_Pr(self):
        """Store the normalized reaction weight P(r) in ``rxn_dict['weight']``.

        Raises:
            ValueError: if no reaction has any valid molecule, so that every
                weight is zero and P(r) cannot be normalized.
        """
        total_weights = 0
        for rxn_dict in self.reaction_data:
            rxn_weight = sum(1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols'])
            rxn_dict['weight'] = rxn_weight
            total_weights += rxn_weight
        if total_weights == 0:
            raise ValueError('No reaction contains a valid molecule; cannot normalize P(r).')
        for rxn_dict in self.reaction_data:
            rxn_dict['weight'] = rxn_dict['weight'] / total_weights

    def _calculate_Pir(self):
        """Store the normalized molecule weights P(i|r) in ``rxn_dict['mol_weight']``.

        Reactions without valid molecules get an empty dict.
        """
        for rxn_dict in self.reaction_data:
            mol_weight = {mol: 1 / self.mol_counter[mol] for mol in rxn_dict['valid_mols']}
            total_weight = sum(mol_weight.values())
            rxn_dict['mol_weight'] = {m: w / total_weight for m, w in mol_weight.items()}

    def choose_mol(self, valid_mols, k=4, weights=None):
        """Pick up to ``k`` molecules from ``valid_mols``.

        If ``k >= len(valid_mols)``, every molecule is taken and ``weights`` is
        ignored. Otherwise ``k`` distinct molecules are drawn with probabilities
        ``weights`` (uniform when ``None``). The picks keep their original
        relative order. With probability ``self.reverse_ratio`` that order is
        reversed.
        """
        if k >= len(valid_mols):
            sampled_indices = list(range(len(valid_mols)))
        else:
            sampled_indices = np.random.choice(len(valid_mols), k, replace=False, p=weights)
            sampled_indices = list(sampled_indices)
        sampled_indices = sorted(sampled_indices)
        if random.random() < self.reverse_ratio:  # reverse the order with reverse_ratio chance
            sampled_indices.reverse()
        sampled_mols = [valid_mols[i] for i in sampled_indices]
        return sampled_mols

    def _build_mol_property_batch(self, rxn, sampled_mols):
        """Return copies of the property dicts of ``sampled_mols``, each tagged with its ``'role'`` in ``rxn``."""
        mol_property_batch = []
        for mol in sampled_mols:
            # Copy: the same molecule can play different roles in different
            # reactions, so per-sample fields must not go into the shared map.
            mol_property = dict(self.mol_property_map[mol])
            mol_property['role'] = rxn['mol_role_map'][mol]
            mol_property_batch.append(mol_property)
        return mol_property_batch

    def sample_mol_batch(self, index=None, k=4):
        """Sample up to ``k`` molecules of one reaction, weighted by P(i|r).

        Args:
            index: reaction index. When ``None``, a reaction is drawn with P(r).
            k: maximum number of molecules to return.

        Returns:
            A list of property-dict copies, each with a ``'role'`` key. When the
            reaction has an ``'rsmiles_map'``, one of its maps is chosen at
            random and supplies ``'r_smiles'`` for the molecules it covers. An
            empty list is returned for a reaction without valid molecules.
        """
        if index is None:
            index = self.sample_rxn_index(1)[0]
        assert index < len(self.reaction_data)
        rxn = self.reaction_data[index]
        if not rxn['mol_weight']:
            # No valid molecule: nothing to sample (zip(*{}) cannot be unpacked).
            return []
        valid_mols, weights = zip(*rxn['mol_weight'].items())
        sampled_mols = self.choose_mol(valid_mols, k=k, weights=weights)
        mol_property_batch = self._build_mol_property_batch(rxn, sampled_mols)
        if 'rsmiles_map' in rxn:
            rsmiles_map = random.choice(rxn['rsmiles_map'])
            for mol_property in mol_property_batch:
                canon_smiles = mol_property['canon_smiles']
                if canon_smiles in rsmiles_map:
                    mol_property['r_smiles'] = rsmiles_map[canon_smiles]
        return mol_property_batch

    def sample_rxn_index(self, num_samples):
        """Draw ``num_samples`` distinct reaction indices with probability P(r).

        Sampling is without replacement. NumPy raises ``ValueError`` if fewer
        than ``num_samples`` reactions have a non-zero weight.
        """
        indices = range(len(self.reaction_data))
        weights = [d['weight'] for d in self.reaction_data]
        return np.random.choice(indices, num_samples, replace=False, p=weights)

    def __call__(self, rxn_num=1000, k=4):
        """Return ``rxn_num`` batches, one per reaction drawn with P(r), of up to ``k`` molecules each."""
        sampled_indices = self.sample_rxn_index(rxn_num)
        sampled_batch = [self.sample_mol_batch(idx, k=k) for idx in sampled_indices]
        return sampled_batch

    def generate_batch_uniform_rxn(self, rxn_num=1000, k=4):
        """Like ``__call__``, but with uniform sampling instead of weights.

        ``rxn_num`` distinct reactions are drawn uniformly from
        ``valid_rxn_indices``, and up to ``k`` of each reaction's valid
        molecules are chosen uniformly. ``'rsmiles_map'`` is not applied here.
        """
        assert rxn_num <= len(self.valid_rxn_indices)
        sampled_rxn_indices = random.sample(self.valid_rxn_indices, rxn_num)
        sampled_batch = []
        for rxn_id in sampled_rxn_indices:
            rxn = self.reaction_data[rxn_id]
            sampled_mols = self.choose_mol(rxn['valid_mols'], k=k, weights=None)
            sampled_batch.append(self._build_mol_property_batch(rxn, sampled_mols))
        return sampled_batch

    def generate_batch_uniform_mol(self, rxn_num=1000, k=4):
        """Group uniformly drawn molecule occurrences into batches of ``k``.

        ``rxn_num * k`` occurrences are drawn uniformly without replacement
        from the multiset of (reaction, valid molecule) occurrences. They are
        split into ``rxn_num`` groups of ``k`` property-dict copies. No
        ``'role'`` is attached.
        """
        valid_mols = list(self.mol_counter.elements())
        assert rxn_num * k <= len(valid_mols)
        sampled_batch = []
        sampled_mol_ids = random.sample(range(len(valid_mols)), rxn_num * k)
        for i in range(rxn_num):
            sampled_batch.append([dict(self.mol_property_map[valid_mols[mol_id]])
                                  for mol_id in sampled_mol_ids[i * k:(i + 1) * k]])
        return sampled_batch

    def generate_batch_single(self, rxn_num=1000):
        """Return ``rxn_num`` single-molecule batches.

        Occurrences are drawn uniformly without replacement from the multiset
        of valid-molecule occurrences. Each batch holds one property-dict copy.
        """
        valid_mols = list(self.mol_counter.elements())
        sampled_mols = random.sample(valid_mols, rxn_num)
        total_valid_mols = [[dict(self.mol_property_map[mol])] for mol in sampled_mols]
        return total_valid_mols

    # Compute the probability of each molecule in the caption dataset.
    def visualize_mol_distribution(self):
        """Return the molecule and reaction sampling probabilities, scaled by population size.

        Prints the molecule and reaction counts, then returns a tuple of two
        arrays:

        * ``prob_values``: P(i) = sum over r of P(r) * P(i|r), for every
          molecule in ``mol_property_map`` (in dict order), multiplied by the
          number of molecules N.
        * ``rxn_weights``: P(r) for every reaction, multiplied by the number
          of reactions M.

        After scaling, a value of 1.0 equals the uniform probability.
        """
        prob_dict = {mol: 0.0 for mol in self.mol_property_map.keys()}
        N = len(prob_dict)
        M = len(self.reaction_data)
        assert N == len(self.mol_property_map)
        print(f'Number of molecules in Caption Dataset: {N}')
        print(f'Number of Reactions in Reaction Dataset: {M}')
        # Probability distribution over molecules.
        for rxn_dict in self.reaction_data:
            for mol, weight in rxn_dict['mol_weight'].items():
                prob_dict[mol] += weight * rxn_dict['weight']
        # The sum of prob_dict.values() should already be 1.
        prob_values = np.array(list(prob_dict.values()))
        prob_values *= N
        # Probability distribution over reactions.
        rxn_weights = np.array([d['weight'] for d in self.reaction_data])
        # The sum of rxn_weights should already be 1.
        rxn_weights *= M
        return prob_values, rxn_weights

    # Measure how often molecules in the caption dataset are sampled.
    def visualize_mol_frequency(self, rxn_num=1000, k=4, epochs=100):
        """Simulate ``epochs`` rounds of weighted sampling and count the hits.

        Each round draws ``rxn_num`` reactions with P(r), then up to ``k``
        molecules per reaction with P(i|r). Reactions without valid molecules
        are skipped.

        Returns:
            A tuple ``(sampled_mols_count, sampled_rxns_count)``. The first
            holds counts sorted by SMILES, the second counts sorted by reaction
            index. Only items sampled at least once appear.
        """
        sampled_mols_counter = Counter()
        sampled_rxns_counter = Counter()
        for _ in range(epochs):
            rxn_indices = self.sample_rxn_index(rxn_num)
            sampled_rxns_counter.update(rxn_indices)
            for index in rxn_indices:
                rxn = self.reaction_data[index]
                if len(rxn['valid_mols']) == 0:
                    continue
                valid_mols, weights = zip(*rxn['mol_weight'].items())
                mol_batch = self.choose_mol(valid_mols, k=k, weights=weights)
                sampled_mols_counter.update(mol_batch)
        sampled_mols_count = np.array([c for _, c in sorted(sampled_mols_counter.items())])
        sampled_rxns_count = np.array([c for _, c in sorted(sampled_rxns_counter.items())])
        return sampled_mols_count, sampled_rxns_count

    def _randomly(self, func, *args, **kwargs):
        """Call ``func(*args, **kwargs)`` with uniform sampling weights temporarily installed.

        For the duration of the call, every reaction (including those with no
        valid molecule) gets weight ``1 / len(reaction_data)``. Every valid
        molecule gets an equal share within its reaction. The original
        weights are restored afterwards, even if ``func`` raises. ``func`` is
        presumably a method of this object that reads those weights (e.g.
        ``self.__call__``); verify at call sites.
        """
        # Back up the real weights and install uniform ones.
        num_rxns = len(self.reaction_data)
        for rxn_dict in self.reaction_data:
            old_mol_weight = rxn_dict['mol_weight']
            rxn_dict['weight_bak'] = rxn_dict['weight']
            rxn_dict['weight'] = 1 / num_rxns
            rxn_dict['mol_weight_bak'] = old_mol_weight
            rxn_dict['mol_weight'] = {m: 1 / len(old_mol_weight) for m in old_mol_weight}
        try:
            return func(*args, **kwargs)
        finally:
            # Restore the real weights, including when func raised.
            for rxn_dict in self.reaction_data:
                rxn_dict['weight'] = rxn_dict.pop('weight_bak')
                rxn_dict['mol_weight'] = rxn_dict.pop('mol_weight_bak')