"""Pre-training dataset that mixes reaction contexts, captions and synthesis samples."""
import json
import os
import random
from collections import defaultdict

import torch
from torch_geometric.data import Dataset

from .context_gen import Reaction_Cluster
from .data_utils import smiles2data, reformat_smiles
from data_provider.caption_dataset import PretrainCaptionDataset
from data_provider.synthesis_dataset import SynthesisDataset


def format_float_from_string(s):
    """Return *s* rendered with two decimals when it parses as a float.

    Strings that ``float()`` rejects are returned unchanged.
    """
    try:
        float_value = float(s)
    except ValueError:
        return s
    return f'{float_value:.2f}'


class MoleculeAbstract(Dataset):
    """Dataset of molecule batches with SMILES prompts and text descriptions.

    Every item is a ``(graph_list, mol_prompt_list, text_prompt_list)`` tuple.
    Items are laid out in three consecutive index ranges:

    1. reaction/molecule batches produced by ``Reaction_Cluster``
       (``self.data_list``),
    2. caption batches (``self.caption_index_list``), present only when
       ``use_caption_dataset`` is true,
    3. synthesis samples (``self.synthesis_index_list``), present only when
       ``synthesis_datasetpath`` is given.

    ``reload_data_list()`` re-samples all three ranges.
    """

    # Display label for each value of ``mol_dict['role']``.
    _ROLE_LABELS = {
        'REACTANT': 'Reactant',
        'CATALYST': 'Catalyst',
        'SOLVENT': 'Solvent',
        'PRODUCT': 'Product',
    }

    _CONTEXT_STYLES = ('weighted_rxn', 'uniform_rxn', 'uniform_mol', 'single_mol', 'hybrid')

    def __init__(self,
                 root,
                 rxn_num=1000,
                 rxn_batch_size=4,
                 smi_max_len=128,
                 prompt=None,
                 disable_graph_cache=False,
                 disable_graphs=False,
                 context_style='weighted_rxn',
                 use_caption_dataset=False,
                 caption_batch_num=10000,
                 synthesis_datasetpath=None,
                 synthesis_batch_num=10000,
                 reverse_ratio=0.5,
                 enable_abstract=True,
                 enable_property=True,
                 smiles_type='default',
                 mode='train'
                 ):
        """Build the dataset and draw the first sample of items.

        Args:
            root: Dataset directory; holds ``mol_graph_map.pt`` and the
                ``reactions/`` JSON files. Caption data is read from
                ``<root>/../caption_data``.
            rxn_num: Number of batches requested from the reaction cluster.
            rxn_batch_size: Molecules per batch (``k``); also the number of
                captions grouped into one caption item.
            smi_max_len: Maximum number of SMILES characters kept in a prompt.
            prompt: Accepted for compatibility; not used by this class.
            disable_graph_cache: Build graphs with ``smiles2data`` on access
                instead of reading them from ``mol_graph_map``.
            disable_graphs: Return ``None`` in place of every graph.
            context_style: One of ``'weighted_rxn'``, ``'uniform_rxn'``,
                ``'uniform_mol'``, ``'single_mol'``, ``'hybrid'``.
            use_caption_dataset: Append caption items to the dataset.
            caption_batch_num: Number of caption items to sample.
            synthesis_datasetpath: Path of the synthesis dataset; a falsy
                value disables synthesis items.
            synthesis_batch_num: Number of synthesis items to sample.
            reverse_ratio: Forwarded to ``Reaction_Cluster``.
            enable_abstract: Include the ``[Abstract]`` section in text prompts.
            enable_property: Include the ``[Properties]`` section in text prompts.
            smiles_type: SMILES flavour; ``'r_smiles'`` selects the
                ``reactions_wRSMILES.json`` reaction file.
            mode: ``'test'`` selects ``reactions_test.json``; any other value
                selects ``reactions.json``.
        """
        super().__init__(root)
        self.root = root
        self.rxn_num = rxn_num
        self.rxn_batch_size = rxn_batch_size
        self.smi_max_len = smi_max_len
        self.context_style = context_style
        self.tokenizer = None
        self.disable_graph_cache = disable_graph_cache
        self.disable_graphs = disable_graphs
        self.use_caption_dataset = use_caption_dataset
        self.smiles_type = smiles_type

        self.caption_batch_num = caption_batch_num
        if use_caption_dataset:
            self.caption_dataset = PretrainCaptionDataset(
                os.path.join(root, '../caption_data'),
                smi_max_len=smi_max_len,
                use_graph=not self.disable_graphs,
                disable_graph_cache=disable_graph_cache,
                smiles_type=smiles_type,
            )

        self.synthesis_batch_num = synthesis_batch_num
        self.use_synthesis_dataset = bool(synthesis_datasetpath)
        if self.use_synthesis_dataset:
            self.synthesis_dataset = SynthesisDataset(
                synthesis_datasetpath,
                'train',
                smi_max_len,
                roundrobin_train=True,
                use_graph=not disable_graphs,
                disable_graph_cache=disable_graph_cache,
                smiles_type='default',
            )

        if not self.disable_graphs:
            # NOTE: torch.load unpickles the file -- only point `root` at trusted data.
            # NOTE(review): recent PyTorch releases default to weights_only=True,
            # which may reject a non-tensor payload here -- confirm against the
            # PyTorch version in use.
            self.mol_graph_map = torch.load(os.path.join(self.root, 'mol_graph_map.pt'))

        reaction_filename = 'reactions/reactions_test.json' if mode == 'test' else 'reactions/reactions.json'
        if smiles_type == 'r_smiles':
            # The R-SMILES file takes precedence over the train/test choice above.
            reaction_filename = 'reactions/reactions_wRSMILES.json'
        self.cluster = Reaction_Cluster(self.root, reaction_filename=reaction_filename, reverse_ratio=reverse_ratio)

        self.reload_data_list()
        self.abstract_max_len = 10240
        self.property_max_len = 10240
        self.enable_abstract = enable_abstract
        self.enable_property = enable_property

    def get(self, index):
        """Return the item at *index* (same as ``self[index]``)."""
        return self.__getitem__(index)

    def len(self):
        """Return the number of items (same as ``len(self)``)."""
        # `len` here is the builtin, which dispatches to __len__.
        return len(self)

    def __len__(self):
        """Return the total item count across all enabled index ranges."""
        data_len = len(self.data_list)
        if self.use_caption_dataset:
            data_len += len(self.caption_index_list)
        if self.use_synthesis_dataset:
            data_len += len(self.synthesis_index_list)
        return data_len

    def reload_data_list(self):
        """Re-sample reaction batches, caption batches and synthesis indices.

        Raises:
            NotImplementedError: ``context_style`` is not a supported value.
            ValueError: more caption or synthesis items were requested than
                the corresponding dataset contains.
        """
        k = self.rxn_batch_size
        if self.context_style == 'weighted_rxn':
            self.data_list = self.cluster(self.rxn_num, k=k)
        elif self.context_style == 'uniform_rxn':
            self.data_list = self.cluster.generate_batch_uniform_rxn(self.rxn_num, k=k)
        elif self.context_style == 'uniform_mol':
            self.data_list = self.cluster.generate_batch_uniform_mol(self.rxn_num, k=k)
        elif self.context_style == 'single_mol':
            self.data_list = self.cluster.generate_batch_single(self.rxn_num)
        elif self.context_style == 'hybrid':
            # Half of the batches from each of the two generators.
            self.data_list = self.cluster(self.rxn_num//2, k=k)
            self.data_list += self.cluster.generate_batch_uniform_mol(self.rxn_num//2, k=k)
        else:
            raise NotImplementedError(
                f'Unsupported context_style {self.context_style!r}; '
                f'expected one of {self._CONTEXT_STYLES}'
            )

        if self.use_caption_dataset:
            num_captions = self.caption_batch_num * k
            if num_captions > len(self.caption_dataset):
                # Explicit check instead of `assert`, which is removed under `python -O`.
                raise ValueError(
                    f'caption_batch_num * rxn_batch_size = {num_captions} exceeds '
                    f'the caption dataset size {len(self.caption_dataset)}'
                )
            # Sample without replacement, then cut into groups of k captions.
            caption_index_list = random.sample(range(len(self.caption_dataset)), num_captions)
            self.caption_index_list = [caption_index_list[i*k:(i+1)*k] for i in range(self.caption_batch_num)]
        else:
            self.caption_index_list = []

        if self.use_synthesis_dataset:
            if self.synthesis_dataset.roundrobin_train:
                self.synthesis_dataset.reload_data()
            if self.synthesis_batch_num > len(self.synthesis_dataset):
                raise ValueError(
                    f'synthesis_batch_num = {self.synthesis_batch_num} exceeds '
                    f'the synthesis dataset size {len(self.synthesis_dataset)}'
                )
            self.synthesis_index_list = random.sample(range(len(self.synthesis_dataset)), self.synthesis_batch_num)
        else:
            self.synthesis_index_list = []

    def make_prompt(self, mol_batch, smi_max_len=128):
        """Build the SMILES prompt and text prompt for every molecule.

        Args:
            mol_batch: Iterable of molecule dicts; each needs ``'canon_smiles'``
                and may carry ``'r_smiles'``, ``'role'``, ``'abstract'`` and
                ``'property'``.
            smi_max_len: Maximum number of SMILES characters kept.

        Returns:
            ``(mol_prompt_list, text_prompt_list)``, one entry per molecule.

        Raises:
            KeyError: a molecule has a ``'role'`` outside the known labels.
        """
        mol_prompt_list, text_prompt_list = [], []
        last_role = None
        for mol_dict in mol_batch:
            smiles = mol_dict['canon_smiles']
            if self.smiles_type == 'r_smiles':
                # Without an 'r_smiles' entry the canonical SMILES is used as-is.
                if 'r_smiles' in mol_dict:
                    smiles = mol_dict['r_smiles']
            else:
                smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
            mol_prompt = f'[START_SMILES]{smiles[:smi_max_len]}[END_SMILES]. '
            if 'role' in mol_dict:
                role = self._ROLE_LABELS[mol_dict['role']]
                # Only the first molecule of a run of equal roles gets the label.
                if last_role != role:
                    mol_prompt = f'{role}: {mol_prompt}'
                last_role = role
            mol_prompt_list.append(mol_prompt)
            text_prompt_list.append(self.make_abstract(mol_dict))
        return mol_prompt_list, text_prompt_list

    def make_abstract(self, mol_dict):
        """Return the text prompt (abstract and/or properties) of one molecule.

        The abstract is cut to ``abstract_max_len`` characters. Properties are
        taken from the ``'Experimental Properties'`` and ``'Computed
        Properties'`` groups, in that order; once an entry would push the
        property text past ``property_max_len``, the rest of that group is
        skipped.
        """
        prompt = ''
        if self.enable_abstract and 'abstract' in mol_dict:
            abstract_string = mol_dict['abstract'][:self.abstract_max_len]
            prompt += f'[Abstract] {abstract_string} '

        if self.enable_property:
            property_string = ''
            property_dict = mol_dict['property'] if 'property' in mol_dict else {}
            for property_key in ['Experimental Properties', 'Computed Properties']:
                if property_key not in property_dict:
                    continue
                for key, value in property_dict[property_key].items():
                    if isinstance(value, float):
                        key_value_string = f'{key}: {value:.2f}; '
                    elif isinstance(value, str):
                        # Numeric strings are normalised to two decimals.
                        float_value = format_float_from_string(value)
                        key_value_string = f'{key}: {float_value}; '
                    else:
                        key_value_string = f'{key}: {value}; '
                    if len(property_string+key_value_string) > self.property_max_len:
                        # Leaves only the current group; the next group is still visited.
                        break
                    property_string += key_value_string
            if property_string:
                prompt += f'[Properties] {property_string}. '
        return prompt

    def get_caption_data(self, index):
        """Return the caption item at *index* within the caption range.

        Returns:
            ``(graph_list, mol_prompt_list, text_prompt_list)`` with one entry
            per caption in the group.
        """
        graph_list, mol_prompt_list, text_prompt_list = [], [], []
        for idx in self.caption_index_list[index]:
            graph_item, text, smiles_prompt = self.caption_dataset[idx]
            graph_list.append(graph_item)
            mol_prompt_list.append(smiles_prompt)
            text_prompt_list.append(text)
        return graph_list, mol_prompt_list, text_prompt_list

    def get_synthesis_data(self, index):
        """Return the synthesis item at *index* within the synthesis range.

        Returns:
            ``(graph_list, [input_text], [output_text])``.
        """
        synthesis_index = self.synthesis_index_list[index]
        _, graph_list, output_text, input_text = self.synthesis_dataset[synthesis_index]
        return graph_list, [input_text], [output_text]

    def __getitem__(self, index):
        """Return ``(graph_list, mol_prompt_list, text_prompt_list)`` for *index*.

        Raises:
            IndexError: *index* lies beyond the last item.
            KeyError: the graph cache is enabled and has no entry for a
                molecule's canonical SMILES.
        """
        num_rxn = len(self.data_list)
        num_caption = len(self.caption_index_list)
        if index < num_rxn:
            mol_batch = self.data_list[index]
        elif index < num_rxn + num_caption:
            # Invariant: the caption range is empty unless captions are enabled.
            assert self.use_caption_dataset
            return self.get_caption_data(index - num_rxn)
        else:
            if not self.use_synthesis_dataset:
                # Same exception the synthesis branch raises for an
                # out-of-range offset.
                raise IndexError(f'index {index} out of range for dataset of length {len(self)}')
            return self.get_synthesis_data(index - (num_rxn + num_caption))

        graph_list = []
        for mol_dict in mol_batch:
            smiles = mol_dict['canon_smiles']
            if self.disable_graphs:
                graph_item = None
            elif self.disable_graph_cache:
                graph_item = smiles2data(smiles)
            else:
                if smiles not in self.mol_graph_map:
                    raise KeyError(f'SMILES {smiles!r} not found in mol_graph_map')
                graph_item = self.mol_graph_map[smiles]
            graph_list.append(graph_item)
        mol_prompt_list, text_prompt_list = self.make_prompt(mol_batch, smi_max_len=self.smi_max_len)
        return graph_list, mol_prompt_list, text_prompt_list