Spaces:
Runtime error
Runtime error
import torch | |
from torch_geometric.data import Dataset | |
import os | |
from torch_geometric.data import InMemoryDataset | |
import random | |
import json | |
from .data_utils import reformat_smiles | |
class ChEBI_dataset(Dataset):
    """Molecule/text dataset read from a tab-separated ``<mode>.txt`` file.

    The first line of ``<root>/<mode>.txt`` is treated as a header and
    skipped. Every remaining non-blank line is split on the first two tab
    characters into three fields, which are later unpacked as
    ``(cid, smiles, text)``.

    When ``use_graph`` is true, ``<root>/cid_graph_map.pt`` is loaded with
    ``torch.load`` and looked up by ``cid`` for each sample.

    Args:
        root: Directory holding ``<mode>.txt`` (and ``cid_graph_map.pt`` when
            graphs are used).
        mode: Base name of the split file, e.g. the ``train`` in
            ``train.txt``.
        smi_max_len: Maximum number of SMILES characters kept in the prompt.
        use_graph: Whether to load the graph map and return graphs.
        disable_graph_cache: Accepted for interface compatibility; not used
            by this class.
        smiles_type: Forwarded to ``reformat_smiles`` as ``smiles_type``.
    """

    def __init__(self, root, mode, smi_max_len=128, use_graph=True,
                 disable_graph_cache=False, smiles_type='default'):
        super(ChEBI_dataset, self).__init__(root)
        self.root = root
        self.file_path = os.path.join(root, f'{mode}.txt')
        self.smi_max_len = smi_max_len
        self.tokenizer = None
        self.use_graph = use_graph
        self.smiles_type = smiles_type

        # Always define the attribute so it can be inspected even when
        # graphs are disabled.
        self.idx_graph_map = None
        if self.use_graph:
            # NOTE(review): torch.load may unpickle arbitrary objects; only
            # load graph maps that come from a trusted source.
            self.idx_graph_map = torch.load(
                os.path.join(root, 'cid_graph_map.pt'))

        with open(self.file_path) as f:
            next(f, None)  # skip the header line
            # Blank lines (e.g. a trailing newline at end of file) would
            # produce one-field records that break the three-way unpack in
            # __getitem__, so they are dropped here.
            self.data = [
                line.split('\t', maxsplit=2) for line in f if line.strip()
            ]

    def get(self, index):
        """Return the sample at ``index``; same as ``self[index]``."""
        return self.__getitem__(index)

    def len(self):
        """Return the number of records; same as ``len(self)``."""
        return len(self.data)

    def __len__(self):
        """Return the number of records read from the split file."""
        return len(self.data)

    def __getitem__(self, index):
        """Build one sample.

        Returns:
            A tuple ``(index, graph_list, text, smiles_prompt)`` where
            ``graph_list`` is a one-element list holding the graph stored
            under this record's ``cid`` (or an empty list when
            ``use_graph`` is false), ``text`` is the stripped description
            followed by a newline, and ``smiles_prompt`` is the reformatted
            SMILES, truncated to ``smi_max_len`` characters, wrapped in
            ``[START_I_SMILES]``/``[END_I_SMILES]`` markers.
        """
        cid, smiles, text = self.data[index]
        smiles = reformat_smiles(smiles, smiles_type=self.smiles_type)
        smiles_prompt = (
            f'[START_I_SMILES]{smiles[:self.smi_max_len]}[END_I_SMILES]. '
        )
        text = text.strip() + '\n'

        # Bind graph_list on every path: previously it was left undefined
        # when graphs were disabled, so the return below failed.
        if self.use_graph:
            graph_list = [self.idx_graph_map[cid]]
        else:
            graph_list = []
        return index, graph_list, text, smiles_prompt