import json
import os
import sys
import time

import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate


class QADataset(Dataset):
    def __init__(self, pdb_root, seq_root, ann_paths, chain="A"):
        """
        pdb_root (string): root directory of protein pdb embeddings (e.g. xyz/pdb/)
        seq_root (string): root directory of sequence embeddings (e.g. xyz/seq/)
        ann_paths (string): path to the JSON annotation file mapping pdb ids to Q/A pairs
        chain (string): chain identifier attached to every sample (default "A")
        """
        self.pdb_root = pdb_root
        self.seq_root = seq_root

        # Annotations are a dict: {pdb_id: [{"Q": ..., "A": ...}, ...]}
        self.qa = json.load(open(ann_paths, "r"))
        self.qa_keys = list(self.qa.keys())

        # Optional train/test filtering via a dataset_description JSON
        # (disabled; re-enable together with a dataset_description argument):
        # data_describe = json.load(open(dataset_description, "r"))
        # train_set = set(data_describe["train"])
        # keep = {}
        # for key in self.qa_keys:
        #     if key in train_set:
        #         keep[key] = self.qa[key]
        # self.qa = keep
        # self.qa_keys = list(self.qa.keys())  # refresh keys after filtering

        # Flatten to one (qa_pair, pdb_id) tuple per question
        self.questions = []
        for key in self.qa_keys:
            for j in range(0, len(self.qa[key])):
                self.questions.append((self.qa[key][j], key))
        self.chain = chain

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, index):
        qa_pair, pdb_id = self.questions[index]

        # Pre-computed structure embedding for this pdb id (loaded on CPU so
        # DataLoader workers never touch the GPU)
        pdb_embedding_path = os.path.join(self.pdb_root, '{}.pt'.format(pdb_id))
        pdb_embedding = torch.load(pdb_embedding_path, map_location=torch.device('cpu'))
        pdb_embedding.requires_grad = False

        # Pre-computed sequence embedding for the same pdb id
        seq_embedding_path = os.path.join(self.seq_root, '{}.pt'.format(pdb_id))
        seq_embedding = torch.load(seq_embedding_path, map_location=torch.device('cpu'))
        seq_embedding.requires_grad = False

        return {
            "q_input": str(qa_pair['Q']),
            "a_input": str(qa_pair['A']),
            "pdb_encoder_out": pdb_embedding,
            "seq_encoder_out": seq_embedding,
            "chain": self.chain,
            "pdb_id": pdb_id,
        }

    def collater(self, samples):
        # Embeddings differ in length along dim 0, so zero-pad every sample to the
        # longest one in the batch before default_collate stacks them. The pad
        # spec assumes 3-D embedding tensors with the variable length first.
        max_len_pdb_dim0 = max(s["pdb_encoder_out"].shape[0] for s in samples)
        max_len_seq_dim0 = max(s["seq_encoder_out"].shape[0] for s in samples)
        for s in samples:
            pdb_embeddings = s["pdb_encoder_out"]
            pad_pdb = ((0, max_len_pdb_dim0 - pdb_embeddings.shape[0]), (0, 0), (0, 0))
            s["pdb_encoder_out"] = torch.from_numpy(np.pad(pdb_embeddings.numpy(), pad_pdb, mode='constant'))
            seq_embeddings = s["seq_encoder_out"]
            pad_seq = ((0, max_len_seq_dim0 - seq_embeddings.shape[0]), (0, 0), (0, 0))
            s["seq_encoder_out"] = torch.from_numpy(np.pad(seq_embeddings.numpy(), pad_seq, mode='constant'))
        return default_collate(samples)
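

# Minimal usage sketch (an illustrative addition, not part of the original module):
# run this file directly with three paths to check that loading and padded batching
# work end to end. The command-line driver below is an assumption; the original code
# never instantiates QADataset itself.
#
#   python <this_file>.py <pdb_root> <seq_root> <ann_paths>
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    pdb_root, seq_root, ann_paths = sys.argv[1], sys.argv[2], sys.argv[3]
    dataset = QADataset(pdb_root, seq_root, ann_paths, chain="A")
    loader = DataLoader(dataset, batch_size=2, collate_fn=dataset.collater)
    batch = next(iter(loader))
    print(batch["pdb_id"], batch["pdb_encoder_out"].shape, batch["seq_encoder_out"].shape)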