import random

import selfies as sf
import torch
from datasets import load_dataset
from rdkit import Chem
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from transformers import T5EncoderModel
def get_dataloader(dataset, batchsize, rank, world_size):
    """Return an infinite batch iterator over `dataset`, sharded across DDP ranks."""
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True
    )

    def collate(batch):
        selfies_ids = [i["selfies_ids"] for i in batch]
        caption_state = [i["caption_state"] for i in batch]
        caption_mask = [i["caption_mask"] for i in batch]
        corrupted_selfies_ids = [i["corrupted_selfies_ids"] for i in batch]
        # Each field is a (1, L) or (1, L, H) tensor, so concatenating on
        # dim 0 stacks them into per-batch tensors.
        return (
            torch.concat(selfies_ids, dim=0),
            torch.concat(caption_state, dim=0),
            torch.concat(caption_mask, dim=0),
            torch.concat(corrupted_selfies_ids, dim=0),
        )

    dataloader = DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=False,  # shuffling is delegated to the DistributedSampler
        collate_fn=collate,
        sampler=sampler,
    )

    def cycle():
        # Loop over the dataloader forever, bumping the sampler epoch on each
        # pass so all ranks reshuffle consistently between epochs.
        epoch = 0
        while True:
            dataloader.sampler.set_epoch(epoch)
            for batch in dataloader:
                yield batch
            epoch += 1

    # A generator is already an iterator, so it can be returned directly.
    return cycle()
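# Minimal usage sketch, assuming torch.distributed has already been
# initialized (e.g. launched with torchrun); `my_dataset` is a placeholder:
#
#   rank = torch.distributed.get_rank()
#   world_size = torch.distributed.get_world_size()
#   loader = get_dataloader(my_dataset, batchsize=32, rank=rank, world_size=world_size)
#   selfies_ids, caption_state, caption_mask, corrupted_ids = next(loader)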
class Lang2molDataset_train(Dataset):
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=256,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob  # probability of atom-order augmentation (see permute)
        self.corrupt_prob = corrupt_prob  # stored but currently unused (see __getitem__)
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        # Frozen caption encoder, used only for feature extraction.
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()
    def create_data(self):
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            ).sort("id")
        except TypeError:
            # Older versions of `datasets` expect `use_auth_token` instead of `token`.
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            ).sort("id")
        return [
            (int(sample_id), sample_selfies, sample_caption, sample_smiles)
            for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
                dataset["id"],
                dataset["selfies"],
                dataset["caption"],
                dataset["smiles"],
            )
        ]
    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        # With probability `prob`, rewrite the SELFIES with a shuffled atom
        # order as a light data augmentation.
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies
    def __getitem__(self, idx):
        data = self.ori_data[idx]
        sample = {
            "id": data[0],
            "selfies": self.permute(data[1]),
            "caption": data[2],
            "smiles": data[3],
        }

        # Molecules
        output_molecule = self.tokenizer(
            sample["selfies"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        sample["selfies_ids"] = output_molecule["input_ids"]
        # Despite `corrupt_prob`, no corruption is applied here: the corrupted
        # ids are currently an alias of the clean ones.
        sample["corrupted_selfies_ids"] = sample["selfies_ids"]

        # Captions: embed with the frozen encoder. no_grad avoids building an
        # autograd graph for a model that is never trained.
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample
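# Construction sketch for the training dataset. The tokenizer checkpoint
# mirrors the caption encoder above and is an assumption; the dataset id is a
# placeholder for a Hugging Face dataset with id/selfies/caption/smiles columns:
#
#   from transformers import T5Tokenizer
#   tokenizer = T5Tokenizer.from_pretrained("QizhiPei/biot5-base-text2mol")
#   train_set = Lang2molDataset_train(
#       dir=".", tokenizer=tokenizer, split="train",
#       dataset_name="<your-hf-dataset>", prob=0.2,
#   )
#   sample = train_set[0]  # dict with selfies_ids, caption_state, caption_mask, ...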
class Lang2molDataset_eval(Dataset):
    # Mirrors Lang2molDataset_train, but with a longer caption budget
    # (512 tokens) and no molecule tokenization in __getitem__.
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=512,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob
        self.corrupt_prob = corrupt_prob
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()
    def create_data(self):
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            ).sort("id")
        except TypeError:
            # Older versions of `datasets` expect `use_auth_token` instead of `token`.
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            ).sort("id")
        return [
            (int(sample_id), sample_selfies, sample_caption, sample_smiles)
            for (sample_id, sample_selfies, sample_caption, sample_smiles) in zip(
                dataset["id"],
                dataset["selfies"],
                dataset["caption"],
                dataset["smiles"],
            )
        ]
    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies
    def __getitem__(self, idx):
        data = self.ori_data[idx]
        sample = {
            "id": data[0],
            "selfies": self.permute(data[1]),
            "caption": data[2],
            "smiles": data[3],
        }
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample
class Lang2molDataset_submission(Dataset):
    # Caption-only dataset used at submission time; no ground-truth molecules.
    def __init__(
        self,
        dir,
        tokenizer,
        split,
        dataset_name,
        pre=None,
        prob=0,
        load_state=True,
        corrupt_prob=0.4,
        token_max_length=256,
    ):
        super().__init__()
        self.dir = dir
        self.tokenizer = tokenizer
        self.split = split
        self.pre = pre
        self.prob = prob
        self.corrupt_prob = corrupt_prob
        self.token_max_length = token_max_length
        self.dataset_name = dataset_name
        self.ori_data = self.create_data()
        self.load_state = load_state
        self.model = T5EncoderModel.from_pretrained("QizhiPei/biot5-base-text2mol")
        self.model.to("cuda")
        self.model.eval()
    def create_data(self):
        try:
            dataset = load_dataset(
                self.dataset_name,
                token=True,
                split=self.split,
            )
        except TypeError:
            # Older versions of `datasets` expect `use_auth_token` instead of `token`.
            dataset = load_dataset(
                self.dataset_name,
                use_auth_token=True,
                split=self.split,
            )
        return list(dataset["caption"])
    def __len__(self):
        return len(self.ori_data)

    def permute(self, selfies):
        # Unused in this caption-only dataset; kept for interface parity.
        if random.random() < self.prob:
            return changeorder(selfies, shuffle=True)
        return selfies
    def __getitem__(self, idx):
        sample = {"caption": self.ori_data[idx]}
        # Captions
        output_caption = self.tokenizer(
            sample["caption"],
            max_length=self.token_max_length,
            truncation=True,
            padding="max_length",
            add_special_tokens=True,
            return_tensors="pt",
            return_attention_mask=True,
        )
        with torch.no_grad():
            sample["caption_state"] = self.model(
                input_ids=output_caption["input_ids"].to("cuda"),
                attention_mask=output_caption["attention_mask"].to("cuda"),
            ).last_hidden_state
        sample["caption_mask"] = output_caption["attention_mask"]
        return sample
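# Inference-time sketch (dataset id is a placeholder; only a "caption" column
# is required):
#
#   test_set = Lang2molDataset_submission(
#       dir=".", tokenizer=tokenizer, split="test",
#       dataset_name="<your-hf-dataset>",
#   )
#   caption_state = test_set[0]["caption_state"]  # (1, token_max_length, hidden_size)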
def changeorder(selfies, shuffle):
    # Round-trip SELFIES -> SMILES -> (renumbered) SMILES -> SELFIES so the
    # same molecule is expressed with a different atom ordering. Note that
    # sf.decoder maps SELFIES to SMILES and sf.encoder maps SMILES to SELFIES.
    smiles = sf.decoder(selfies)
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Unparseable molecule: fall back to the input unchanged.
        return selfies
    Chem.Kekulize(mol)
    atom_indices = [atom.GetIdx() for atom in mol.GetAtoms()]
    if shuffle:
        random.shuffle(atom_indices)
    reordered_mol = Chem.RenumberAtoms(mol, atom_indices)
    new_smiles = Chem.MolToSmiles(reordered_mol, kekuleSmiles=True)
    new_selfies = sf.encoder(new_smiles)
    return new_selfies
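# Quick illustration with aspirin (output varies run to run because of the
# shuffle, but always encodes the same molecule):
#
#   s = sf.encoder("CC(=O)Oc1ccccc1C(=O)O")
#   print(changeorder(s, shuffle=True))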