File size: 3,340 Bytes
85ab89d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import sys

import torch
from torch.utils.data import Dataset
import json
import numpy as np
from torch.utils.data.dataloader import default_collate

import time

class QADataset(Dataset):
    """Question/answer pairs joined with precomputed per-protein embeddings.

    Each item corresponds to one question/answer entry from the annotation
    file, together with the structure ("pdb") and sequence ("seq")
    embeddings stored on disk under the same protein id.
    """

    def __init__(self, pdb_root, seq_root, ann_paths, chain="A"):
        """Load the annotation file and flatten it into a list of questions.

        Args:
            pdb_root (string): Root directory of protein pdb embeddings
                (e.g. xyz/pdb/); files are looked up as ``<pdb_id>.pt``.
            seq_root (string): Root directory of sequence embeddings
                (e.g. xyz/seq/); files are looked up as ``<pdb_id>.pt``.
            ann_paths (string): Path to the JSON annotation file. It is
                passed directly to ``open``. The file maps each protein id
                to a sequence of entries that are indexed with ``'Q'`` and
                ``'A'`` when an item is fetched.
            chain (string): Chain label stored on the dataset and returned
                unchanged with every sample.
        """
        self.pdb_root = pdb_root
        self.seq_root = seq_root
        self.chain = chain

        # Use a context manager so the annotation file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(ann_paths, "r") as ann_file:
            self.qa = json.load(ann_file)

        # NOTE: an earlier revision filtered the annotations down to a train
        # split read from a separate description file. That filtering is
        # disabled, so every entry in the annotation file is used.
        self.qa_keys = list(self.qa.keys())

        # One (entry, pdb_id) tuple per question, in annotation-file order.
        self.questions = [
            (entry, pdb_id)
            for pdb_id in self.qa_keys
            for entry in self.qa[pdb_id]
        ]

    def __len__(self):
        """Return the total number of question/answer entries."""
        return len(self.questions)

    @staticmethod
    def _load_embedding(root, pdb_id):
        """Load ``<root>/<pdb_id>.pt`` onto the CPU with gradients disabled."""
        embedding_path = os.path.join(root, '{}.pt'.format(pdb_id))
        # SECURITY NOTE(review): torch.load unpickles the file contents; only
        # point pdb_root / seq_root at embedding files from a trusted source.
        embedding = torch.load(
            embedding_path, map_location=torch.device('cpu'))
        embedding.requires_grad = False
        return embedding

    def __getitem__(self, index):
        """Return the sample dict for the question at ``index``.

        Returns:
            dict with keys ``q_input`` and ``a_input`` (the entry's ``'Q'``
            and ``'A'`` values converted with ``str``), ``pdb_encoder_out``
            and ``seq_encoder_out`` (the loaded embeddings), ``chain`` and
            ``pdb_id``.
        """
        entry, pdb_id = self.questions[index]

        pdb_embedding = self._load_embedding(self.pdb_root, pdb_id)
        seq_embedding = self._load_embedding(self.seq_root, pdb_id)

        return {
            "q_input": str(entry['Q']),
            "a_input": str(entry['A']),
            "pdb_encoder_out": pdb_embedding,
            "seq_encoder_out": seq_embedding,
            "chain": self.chain,
            "pdb_id": pdb_id,
        }
    
def collater(self, samples):
    """Zero-pad the embeddings in ``samples`` to a common length and batch them.

    The ``"pdb_encoder_out"`` and ``"seq_encoder_out"`` tensors of every
    sample are padded with zeros at the end of their first dimension, up to
    the largest first-dimension size found in the batch for that key. The
    padded samples are then combined with ``default_collate``.

    Args:
        self: Unused. NOTE(review): this function is defined at module level
            yet takes ``self``; it looks like it was meant to be a method of
            the dataset class. The signature is kept as-is for existing
            callers.
        samples (list[dict]): Sample dicts that each contain tensors under
            ``"pdb_encoder_out"`` and ``"seq_encoder_out"``. Tensors may have
            any rank >= 1; all dimensions other than the first must already
            agree across samples for batching to succeed.

    Returns:
        The result of ``default_collate`` on the padded samples.

    Raises:
        ValueError: If ``samples`` is empty (there is no maximum length).
    """
    padded_keys = ("pdb_encoder_out", "seq_encoder_out")
    max_len = {
        key: max(sample[key].shape[0] for sample in samples)
        for key in padded_keys
    }

    batch = []
    for sample in samples:
        # Pad a shallow copy so the caller's sample dicts are left untouched.
        padded = dict(sample)
        for key in padded_keys:
            embedding = sample[key]
            missing = max_len[key] - embedding.shape[0]
            # F.pad takes (before, after) pairs starting from the LAST
            # dimension, so the pair for dimension 0 comes last. Building the
            # spec from the tensor's rank supports any number of trailing
            # dimensions rather than exactly two.
            pad_spec = [0, 0] * (embedding.dim() - 1) + [0, missing]
            padded[key] = torch.nn.functional.pad(
                embedding, pad_spec, mode='constant', value=0)
        batch.append(padded)

    return default_collate(batch)