# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
from pytorch_lightning import LightningDataModule
import torch_geometric
# from torch_geometric.loader import DataLoader
from torch.utils.data import DataLoader
from torch_geometric.loader.dataloader import Collater
from data_provider.molecule_abstract_dataset import MoleculeAbstract
import re
from transformers import BatchEncoding

# we split individual characters inside special tokens like [START_DNA]
CUSTOM_SEQ_RE = re.compile(r"(\[START_(DNA|SMILES|I_SMILES|AMINO)])(.*?)(\[END_\2])")

# token added to implement a custom sequence tokenization. This token is added at
# corpus cleaning step and removed in pretokenization. The digits are added to increase the chance
# that they do not occur in the corpus. The digits are escaped so that the token does not appear
# literally in the source code in case we ever include it in the training data.
SPLIT_MARKER = f"SPL{1}T-TH{1}S-Pl3A5E"
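
# Illustrative example (hypothetical input, not from the original pipeline): applying
# CUSTOM_SEQ_RE with _insert_split_marker (defined below) to "[START_SMILES]CCO[END_SMILES]"
# prefixes every character of the inner sequence, plus the end marker, with SPLIT_MARKER:
#   [START_SMILES]SPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5ECSPL1T-TH1S-Pl3A5EOSPL1T-TH1S-Pl3A5E[END_SMILES]
# which forces the tokenizer to split the SMILES character by character.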

def _insert_split_marker(m: re.Match):
    """
    Applies split marker based on a regex match of special tokens such as
    [START_DNA].

    Parameters
    ----------
    m : re.Match
        Regex match covering the special-token span to split.

    Returns
    ----------
    str - the text with the split token added
    """
    start_token, _, sequence, end_token = m.groups()
    sequence = re.sub(r"(.)", fr"{SPLIT_MARKER}\1", sequence, flags=re.DOTALL)
    return f"{start_token}{sequence}{SPLIT_MARKER}{end_token}"

def smiles_handler(text, mol_ph, is_gal=True):
    """Extract SMILES spans from `text` and append the molecule placeholder `mol_ph` after each.

    For Galactica-style models (`is_gal=True`) the [START_*]/[END_*] markers are kept and the
    character-level split escaping is applied; otherwise only the raw SMILES string is kept.
    Returns the rewritten text and the list of extracted SMILES.
    """
    smiles_list = []
    for match in CUSTOM_SEQ_RE.finditer(text):
        smiles = match.group(3)
        smiles_list.append(smiles)

    if is_gal:
        text = CUSTOM_SEQ_RE.sub(r'\1\3\4%s' % (mol_ph), text)
        text = escape_custom_split_sequence(text)
        return text, smiles_list
    else:
        text = CUSTOM_SEQ_RE.sub(r'\3%s' % (mol_ph), text)
        return text, smiles_list
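
# Illustrative example (hypothetical values): with mol_ph = '<mol><mol>' and is_gal=True,
#   smiles_handler('[START_SMILES]CCO[END_SMILES]', '<mol><mol>', True)
# keeps the [START_SMILES]CCO[END_SMILES] span, appends '<mol><mol>' right after it, applies
# the split-marker escaping, and returns (escaped_text, ['CCO']). With is_gal=False the span
# is replaced by just 'CCO<mol><mol>'.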

def escape_custom_split_sequence(text):
    """
    Applies custom splitting to the text for GALILEO's tokenization

    Parameters
    ----------
    text : str
        Input text to split

    Returns
    ----------
    str - the text with the split token added
    """
    return CUSTOM_SEQ_RE.sub(_insert_split_marker, text)

def tokenize_and_merge_batched_qa_pairs(tokenizer, qa_pairs_list, max_length):
    tokenized_batches = {
        'input_ids': [],
        'attention_mask': []
    }
    for qa_pairs in qa_pairs_list:
        # split the token budget evenly across the QA pairs of this sample
        max_length_per_qa = max_length // len(qa_pairs)
        batch_input_ids = []
        batch_attention_mask = []
        for qa in qa_pairs:
            # here qa should be a string
            tokens = tokenizer(qa,
                               truncation=True,
                               padding=False,
                               add_special_tokens=False,
                               max_length=max_length_per_qa,
                               return_tensors='pt',
                               return_attention_mask=True)
            # squeeze(0) (rather than squeeze()) keeps single-token encodings as 1-d lists
            batch_input_ids.extend(tokens['input_ids'].squeeze(0).tolist())
            batch_attention_mask.extend(tokens['attention_mask'].squeeze(0).tolist())

        # Pad the concatenated sequence to max_length
        padding_length = max_length - len(batch_input_ids)
        batch_input_ids.extend([tokenizer.pad_token_id] * padding_length)
        batch_attention_mask.extend([0] * padding_length)

        tokenized_batches['input_ids'].append(torch.tensor(batch_input_ids).unsqueeze(0))
        tokenized_batches['attention_mask'].append(torch.tensor(batch_attention_mask).unsqueeze(0))

    tokenized_batches['input_ids'] = torch.cat(tokenized_batches['input_ids'], dim=0)
    tokenized_batches['attention_mask'] = torch.cat(tokenized_batches['attention_mask'], dim=0)
    tokenized_batch = BatchEncoding(data=tokenized_batches, tensor_type='pt')
    return tokenized_batch
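
# Sketch of the merged layout (illustrative numbers): for one sample with QA pairs [q1, q2] and
# max_length=16, each pair is tokenized with a budget of 16 // 2 = 8 tokens, the resulting ids
# are concatenated as [q1_ids..., q2_ids...], then right-padded with pad_token_id (attention
# mask 0) up to 16. Stacking all samples gives a BatchEncoding of shape (batch, max_length).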

class TrainCollater:
    def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False):
        self.text_max_len = text_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.disable_graphs = disable_graphs

    def __call__(self, batch):
        graphs, mol_prompt, text_prompt = zip(*batch)
        if not self.disable_graphs:
            graphs = [graph for graph_batch in graphs for graph in graph_batch]
            graphs = self.collater(graphs)

        qa_pairs = []
        for mol_batch, text_batch in zip(mol_prompt, text_prompt):
            qa_list = []
            # avoid shadowing the outer mol_prompt/text_prompt tuples
            for mol, text in zip(mol_batch, text_batch):
                smiles_prompt = smiles_handler(mol, self.mol_ph, self.is_gal)[0]
                qa_list.append(f'{smiles_prompt} {text}')
            qa_pairs.append(qa_list)

        self.tokenizer.padding_side = 'right'
        qa_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, qa_pairs, self.text_max_len)
        is_mol_token = qa_batch.input_ids == self.mol_token_id
        qa_batch['is_mol_token'] = is_mol_token
        return graphs, qa_batch
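
# Expected input, based on how __call__ unpacks it (dataset-side details are assumptions):
# each dataset item is a tuple (graph_list, mol_prompt_list, text_prompt_list), e.g.
#   ([g1, g2], ['[START_SMILES]CCO[END_SMILES] ...', ...], ['The product is ...', ...])
# so one sample can carry several reaction steps. The collater flattens the graphs, joins each
# (molecule prompt, text) pair into one string, and tokenizes all pairs of a sample into a
# single fixed-length row via tokenize_and_merge_batched_qa_pairs.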

class InferenceCollater:
    def __init__(self, tokenizer, text_max_len, mol_ph, mol_token_id, is_gal=True, disable_graphs=False, last_only=False):
        self.text_max_len = text_max_len
        self.tokenizer = tokenizer
        self.collater = Collater([], [])
        self.mol_ph = mol_ph
        self.mol_token_id = mol_token_id
        self.is_gal = is_gal
        self.disable_graphs = disable_graphs
        self.last_only = last_only

    def __call__(self, batch):
        graphs, mol_prompt, text_prompt = zip(*batch)
        rxn_ids = [0 for i in range(len(mol_prompt))]
        if self.last_only:
            # keep only the final step of every sample
            mol_prompt = [[mol_batch[-1]] for mol_batch in mol_prompt]
            text_prompt = [[text_batch[-1]] for text_batch in text_prompt]
            graphs = [[graph_batch[-1]] for graph_batch in graphs]

        if not self.disable_graphs:
            graphs = [graph for graph_batch in graphs for graph in graph_batch]
            graphs = self.collater(graphs)

        input_text, output_text = [], []
        for mol_batch, text_batch in zip(mol_prompt, text_prompt):
            qa_list = []
            # all but the last pair form the context; the last text is the generation target
            for mol, text in list(zip(mol_batch, text_batch))[:-1]:
                smiles_prompt = smiles_handler(mol, self.mol_ph, self.is_gal)[0]
                qa_list.append(f'{smiles_prompt} {text}')
            qa_list.append(f'{smiles_handler(mol_batch[-1], self.mol_ph, self.is_gal)[0]} ')
            output_text.append(text_batch[-1])
            input_text.append(qa_list)

        self.tokenizer.padding_side = 'right'
        input_batch = tokenize_and_merge_batched_qa_pairs(self.tokenizer, input_text, self.text_max_len)
        is_mol_token = input_batch.input_ids == self.mol_token_id
        input_batch['is_mol_token'] = is_mol_token
        return rxn_ids, graphs, input_batch, output_text, input_text
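
# Return layout (as produced above; downstream use is an assumption): rxn_ids is a list of
# placeholder zeros, graphs is the collated graph batch (or the raw nested lists when graphs
# are disabled), input_batch is the tokenized prompt, output_text holds the reference text of
# the final step for evaluation, and input_text keeps the untokenized prompts.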

class PretrainDM(LightningDataModule):
    def __init__(
        self,
        num_workers: int = 0,
        batch_size: int = 256,
        root: str = 'data/',
        text_max_len: int = 128,
        rxn_max_len: int = 128,
        smi_max_len: int = 128,
        tokenizer=None,
        args=None,
    ):
        super().__init__()
        self.args = args
        self.batch_size = batch_size
        self.inference_batch_size = args.inference_batch_size
        self.num_workers = num_workers
        self.text_max_len = text_max_len
        self.rxn_max_len = rxn_max_len
        self.pretrain_dataset = MoleculeAbstract(
            root,
            rxn_num=args.pretrain_rxn_num,
            rxn_batch_size=args.rxn_batch_size,
            smi_max_len=smi_max_len,
            disable_graph_cache=args.disable_graph_cache,
            context_style=args.context_style,
            disable_graphs=args.disable_graphs,
            use_caption_dataset=args.pretrain_use_caption,
            caption_batch_num=args.caption_batch_num,
            synthesis_datasetpath=args.pretrain_synthesis_path,
            synthesis_batch_num=args.synthesis_batch_num,
            reverse_ratio=args.reverse_ratio,
            enable_abstract=not args.disable_abstract,
            enable_property=not args.disable_property,
            smiles_type=args.smiles_type,
        )
        self.test_dataset = MoleculeAbstract(
            root,
            rxn_num=args.pretrain_rxn_num,
            rxn_batch_size=args.rxn_batch_size,
            smi_max_len=smi_max_len,
            disable_graph_cache=args.disable_graph_cache,
            context_style=args.context_style,
            disable_graphs=args.disable_graphs,
            use_caption_dataset=args.pretrain_use_caption,
            caption_batch_num=args.caption_batch_num,
            reverse_ratio=args.reverse_ratio,
            enable_abstract=not args.disable_abstract,
            enable_property=not args.disable_property,
            smiles_type=args.smiles_type,
            mode='test',
        )
        self.init_tokenizer(tokenizer)
        self.mol_ph_token = '<mol>' * self.args.num_query_token
        self.is_gal = args.opt_model.find('galactica') >= 0
        self.disable_graphs = args.disable_graphs
        self.last_only = args.pretrain_eval_last_only

    def init_tokenizer(self, tokenizer):
        self.tokenizer = tokenizer
        self.pretrain_dataset.tokenizer = tokenizer
        self.test_dataset.tokenizer = tokenizer
        self.mol_token_id = self.tokenizer.mol_token_id
        # self.tokenizer.mol_token_id = tokenizer("<mol>", add_special_tokens=False).input_ids[0]

    def train_dataloader(self):
        self.pretrain_dataset.reload_data_list()
        loader = DataLoader(
            self.pretrain_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=True,
            persistent_workers=True,
            collate_fn=TrainCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                disable_graphs=self.disable_graphs,
            ),
        )
        return loader
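
    # Note: PyTorch's DataLoader raises an error when persistent_workers=True is combined with
    # num_workers=0, so --num_workers should stay >= 1 with the configuration above and below.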

    def val_dataloader(self):
        test_loader = DataLoader(
            self.test_dataset,
            batch_size=self.inference_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=False,
            drop_last=False,
            persistent_workers=True,
            collate_fn=InferenceCollater(
                tokenizer=self.tokenizer,
                text_max_len=self.text_max_len,
                mol_ph=self.mol_ph_token,
                mol_token_id=self.mol_token_id,
                is_gal=self.is_gal,
                disable_graphs=self.disable_graphs,
                last_only=self.last_only,
            ),
        )
        return [test_loader]

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = parent_parser.add_argument_group("Data module")
        parser.add_argument('--num_workers', type=int, default=2)
        parser.add_argument('--batch_size', type=int, default=4)
        parser.add_argument('--inference_batch_size', type=int, default=4)
        parser.add_argument('--use_smiles', action='store_true', default=False)
        parser.add_argument('--root', type=str, default='data/action_data')
        parser.add_argument('--context_style', type=str, default='weighted_rxn', choices=['weighted_rxn', 'uniform_rxn', 'uniform_mol', 'single_mol', 'hybrid'])
        parser.add_argument('--rxn_max_len', type=int, default=512)
        parser.add_argument('--text_max_len', type=int, default=512)
        parser.add_argument('--smi_max_len', type=int, default=128)
        parser.add_argument('--pretrain_rxn_num', type=int, default=50000)
        parser.add_argument('--reverse_ratio', type=float, default=0.5, help='ratio of reversed reactions (retro reactions)')
        parser.add_argument('--disable_abstract', action='store_true', default=False)
        parser.add_argument('--disable_property', action='store_true', default=False)
        parser.add_argument('--pretrain_use_caption', action='store_true', default=False)
        parser.add_argument('--caption_batch_num', type=int, default=5000)
        parser.add_argument('--pretrain_synthesis_path', type=str, default=None)
        parser.add_argument('--synthesis_batch_num', type=int, default=5000)
        parser.add_argument('--rxn_batch_size', type=int, default=4)
        parser.add_argument('--roundrobin_train', action='store_true', default=False)
        parser.add_argument('--test_subset', type=int, default=-1)
        parser.add_argument('--pretrain_eval_last_only', default=False, action='store_true')
        parser.add_argument('--prompt', type=str, default=None)
        return parent_parser
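
if __name__ == '__main__':
    # Minimal usage sketch (an illustrative addition, not part of the original training entry
    # point). It only shows how the data-module arguments are registered; instantiating
    # PretrainDM additionally needs a tokenizer exposing `mol_token_id` and a prepared dataset
    # root, which are not mocked here.
    import argparse

    sketch_parser = argparse.ArgumentParser()
    sketch_parser = PretrainDM.add_model_specific_args(sketch_parser)
    sketch_args = sketch_parser.parse_args([])
    print('default batch_size:', sketch_args.batch_size, '| default text_max_len:', sketch_args.text_max_len)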