ReactXT / data_provider /reaction_action_dataset.py
SyrWin
init
95f97c5
import torch
from torch_geometric.data import Dataset
import os
import random
import json
from .data_utils import smiles2data, reformat_smiles
class ActionDataset(Dataset):
def __init__(self, root, mode, smi_max_len, use_graph=True, disable_graph_cache=False, predict_rxn_condition=False, smiles_type='default'):
super(ActionDataset, self).__init__(root)
self.root = root
self.smi_max_len = smi_max_len
self.tokenizer = None
self.use_graph = use_graph
self.disable_graph_cache = disable_graph_cache
self.predict_rxn_condition = predict_rxn_condition
self.smiles_type = smiles_type
with open(os.path.join(self.root, f'{mode}.json'), encoding='utf-8') as f:
self.data_list = json.load(f)
if self.use_graph:
self.mol_graph_map = torch.load(os.path.join(self.root, 'mol_graph_map.pt'))
# self.data_list = self.data_list[:100]
def get(self, index):
return self.__getitem__(index)
def len(self):
return len(self)
def __len__(self):
return len(self.data_list)
def make_prompt(self, param_dict, smi_max_len=128, predict_rxn_condition=False):
action_sequence = param_dict['actions']
smiles_list = []
prompt = ''
prompt += 'Reactants: '
smiles_wrapper = lambda x: reformat_smiles(x, smiles_type=self.smiles_type)[:smi_max_len]
for smi in param_dict['REACTANT']:
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
smiles_list.append(smi)
prompt += 'Product: '
for smi in param_dict['PRODUCT']:
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
smiles_list.append(smi)
if param_dict['CATALYST']:
prompt += 'Catalysts: '
for smi in param_dict['CATALYST']:
if smi in param_dict["extracted_molecules"]:
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
else:
prompt += f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
smiles_list.append(smi)
if param_dict['SOLVENT']:
prompt += 'Solvents: '
for smi in param_dict['SOLVENT']:
if smi in param_dict["extracted_molecules"]:
prompt += f'{param_dict["extracted_molecules"][smi]}: [START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
else:
prompt += f'[START_SMILES]{smiles_wrapper(smi)}[END_SMILES] '
smiles_list.append(smi)
if predict_rxn_condition:
for value, token in param_dict['extracted_duration'].items():
action_sequence = action_sequence.replace(token, value)
for value, token in param_dict['extracted_temperature'].items():
action_sequence = action_sequence.replace(token, value)
else:
prompt += 'Temperatures: '
for value, token in param_dict['extracted_temperature'].items():
prompt += f'{token}: {value} '
prompt += 'Durations: '
for value, token in param_dict['extracted_duration'].items():
prompt += f'{token}: {value} '
prompt += 'Action Squence: '
return prompt, smiles_list, action_sequence
def __getitem__(self, index):
rxn_dict = self.data_list[index]
rxn_id = rxn_dict['index']
input_text, smiles_list, output_text = self.make_prompt(rxn_dict, self.smi_max_len, self.predict_rxn_condition)
output_text = output_text.strip() + '\n'
graph_list = []
if self.use_graph:
for smiles in smiles_list:
if self.disable_graph_cache:
graph_item = smiles2data(smiles)
else:
assert smiles in self.mol_graph_map
graph_item = self.mol_graph_map[smiles]
graph_list.append(graph_item)
return rxn_id, graph_list, output_text, input_text