File size: 4,212 Bytes
95f97c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
from torch_geometric.data import Dataset
import os
import random
import json
from .data_utils import smiles2data, reformat_smiles

class ActionDataset(Dataset):
    """Dataset of reactions rendered as (prompt, action sequence) text pairs.

    Each record of ``<root>/<mode>.json`` is turned into an input prompt that
    lists the molecules involved in the reaction, and a target string taken
    from the record's ``'actions'`` field. When ``use_graph`` is set, a graph
    object is additionally returned for every molecule named in the prompt.

    Args:
        root: Directory holding ``<mode>.json`` and, when ``use_graph`` is
            true, ``mol_graph_map.pt``.
        mode: Basename (without ``.json``) of the record file to load.
        smi_max_len: Maximum number of characters kept from each reformatted
            SMILES string inside the prompt.
        use_graph: Whether ``__getitem__`` returns molecule graphs.
        disable_graph_cache: When true, graphs are built on the fly with
            ``smiles2data`` instead of being read from ``mol_graph_map``.
        predict_rxn_condition: When true, duration/temperature tokens in the
            target are replaced by their values; otherwise the token/value
            pairs are listed in the prompt.
        smiles_type: Forwarded to ``reformat_smiles`` as ``smiles_type``.
    """

    # (record key, prompt header, header always emitted, molecule name required)
    _PROMPT_SECTIONS = (
        ('REACTANT', 'Reactants: ', True, True),
        ('PRODUCT', 'Product: ', True, True),
        ('CATALYST', 'Catalysts: ', False, False),
        ('SOLVENT', 'Solvents: ', False, False),
    )

    def __init__(self, root, mode, smi_max_len, use_graph=True, disable_graph_cache=False, predict_rxn_condition=False, smiles_type='default'):
        super().__init__(root)
        self.root = root
        self.smi_max_len = smi_max_len
        # Not used inside this class; presumably assigned by the caller later.
        self.tokenizer = None
        self.use_graph = use_graph
        self.disable_graph_cache = disable_graph_cache
        self.predict_rxn_condition = predict_rxn_condition
        self.smiles_type = smiles_type

        with open(os.path.join(self.root, f'{mode}.json'), encoding='utf-8') as f:
            self.data_list = json.load(f)
        if self.use_graph:
            # NOTE(review): torch.load unpickles arbitrary objects; only point
            # `root` at trusted data.
            self.mol_graph_map = torch.load(os.path.join(self.root, 'mol_graph_map.pt'))

    def get(self, index):
        """Return the sample at ``index`` (same as ``self[index]``)."""
        return self.__getitem__(index)

    def len(self):
        """Return the number of reaction records (same as ``len(self)``)."""
        return len(self)

    def __len__(self):
        """Return the number of reaction records loaded from the JSON file."""
        return len(self.data_list)

    def _molecule_entry(self, param_dict, smi, smi_max_len, name_required):
        """Format one molecule as ``'<name>: [START_SMILES]...[END_SMILES] '``.

        The name is looked up in ``param_dict['extracted_molecules']``. When
        ``name_required`` is false and the molecule has no entry there, the
        ``'<name>: '`` prefix is omitted; when it is true a missing entry
        raises ``KeyError``.
        """
        names = param_dict['extracted_molecules']
        if name_required or smi in names:
            label = f'{names[smi]}: '
        else:
            label = ''
        smiles = reformat_smiles(smi, smiles_type=self.smiles_type)[:smi_max_len]
        return f'{label}[START_SMILES]{smiles}[END_SMILES] '

    def make_prompt(self, param_dict, smi_max_len=128, predict_rxn_condition=False):
        """Build the text prompt and target for one reaction record.

        Args:
            param_dict: Reaction record. The keys read here are ``'actions'``,
                ``'REACTANT'``, ``'PRODUCT'``, ``'CATALYST'``, ``'SOLVENT'``,
                ``'extracted_molecules'``, ``'extracted_duration'`` and
                ``'extracted_temperature'``.
            smi_max_len: Maximum number of characters kept from each
                reformatted SMILES string.
            predict_rxn_condition: When true, every duration and temperature
                token found in the action sequence is replaced by its value.
                When false, the action sequence is left untouched and the
                token/value pairs are appended to the prompt instead.

        Returns:
            A tuple ``(prompt, smiles_list, action_sequence)`` where
            ``smiles_list`` holds the original (not reformatted) SMILES keys
            in the order they appear in the prompt.

        Raises:
            KeyError: If a required key is missing from ``param_dict`` or a
                reactant/product has no entry in ``'extracted_molecules'``.
        """
        action_sequence = param_dict['actions']
        smiles_list = []
        parts = []

        for key, header, always_emit, name_required in self._PROMPT_SECTIONS:
            molecules = param_dict[key]
            # Catalyst/solvent headers are only written when there is
            # something to list; reactant/product headers always appear.
            if not (always_emit or molecules):
                continue
            parts.append(header)
            for smi in molecules:
                parts.append(self._molecule_entry(param_dict, smi, smi_max_len, name_required))
                smiles_list.append(smi)

        # Both 'extracted_*' mappings are iterated as value -> token.
        if predict_rxn_condition:
            for value, token in param_dict['extracted_duration'].items():
                action_sequence = action_sequence.replace(token, value)
            for value, token in param_dict['extracted_temperature'].items():
                action_sequence = action_sequence.replace(token, value)
        else:
            parts.append('Temperatures: ')
            for value, token in param_dict['extracted_temperature'].items():
                parts.append(f'{token}: {value} ')

            parts.append('Durations: ')
            for value, token in param_dict['extracted_duration'].items():
                parts.append(f'{token}: {value} ')

        # The spelling below is runtime prompt text and is kept byte-for-byte;
        # changing it would change the model input.
        parts.append('Action Squence: ')
        return ''.join(parts), smiles_list, action_sequence

    def __getitem__(self, index):
        """Return ``(rxn_id, graph_list, output_text, input_text)`` for a record.

        ``output_text`` is the stripped action sequence followed by a newline.
        ``graph_list`` is empty when ``use_graph`` is false; otherwise it has
        one graph per SMILES in the prompt, in prompt order.

        Raises:
            KeyError: If the graph cache is in use and a SMILES string of the
                record is not present in ``mol_graph_map``.
        """
        rxn_dict = self.data_list[index]
        rxn_id = rxn_dict['index']
        input_text, smiles_list, output_text = self.make_prompt(rxn_dict, self.smi_max_len, self.predict_rxn_condition)
        output_text = output_text.strip() + '\n'

        graph_list = []
        if self.use_graph:
            for smiles in smiles_list:
                if self.disable_graph_cache:
                    graph_item = smiles2data(smiles)
                else:
                    # Explicit check instead of `assert`, which is stripped
                    # under `python -O` and carries no diagnostic message.
                    if smiles not in self.mol_graph_map:
                        raise KeyError(
                            f'SMILES {smiles!r} (reaction {rxn_id!r}) is missing from mol_graph_map'
                        )
                    graph_item = self.mol_graph_map[smiles]
                graph_list.append(graph_item)
        return rxn_id, graph_list, output_text, input_text