import os
import sys

import torch
from torch.utils.data import Dataset
import json
import numpy as np
from torch.utils.data.dataloader import default_collate

import time


class ESMDataset(Dataset):
    def __init__(self, pdb_root, seq_root, ann_paths, dataset_description, chain="A"):
        """
        pdb_root (string): Root directory of protein pdb embeddings (e.g. xyz/pdb/)
        seq_root (string): Root directory of sequence embeddings (e.g. xyz/seq/)
        ann_paths (string): path to the annotation json file
        dataset_description (string): json file describing which entries are used for training/testing
        """
        data_describe = json.load(open(dataset_description, "r"))
        train_set = set(data_describe["train"])
        self.pdb_root = pdb_root
        self.seq_root = seq_root
        self.annotation = json.load(open(ann_paths, "r"))
        # keep only the annotations whose pdb_id belongs to the training split
        self.annotation = [ann for ann in self.annotation if ann["pdb_id"] in train_set]
        self.pdb_ids = {}
        self.chain = chain

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):

        ann = self.annotation[index]

        pdb_embedding = '{}.pt'.format(ann["pdb_id"])
        pdb_embedding_path = os.path.join(self.pdb_root, pdb_embedding)
        pdb_embedding = torch.load(
            pdb_embedding_path, map_location=torch.device('cpu'))
            # pdb_embedding_path, map_location=torch.device('cuda'))
        pdb_embedding.requires_grad = False

        seq_embedding = '{}.pt'.format(ann["pdb_id"])
        seq_embedding_path = os.path.join(self.seq_root, seq_embedding)
        seq_embedding = torch.load(
            seq_embedding_path, map_location=torch.device('cpu'))
            # seq_embedding_path, map_location=torch.device('cuda'))
        seq_embedding.requires_grad = False

        caption = ann["caption"]

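        # each item pairs the text caption with the pre-computed structure (pdb) and
        # sequence embeddings for this pdb_id; "chain" is the constant set at init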
        return {
            "text_input": caption,
            "pdb_encoder_out": pdb_embedding,
            "seq_encoder_out": seq_embedding,
            "chain": self.chain,
            "pdb_id": ann["pdb_id"]
        }

    def collater(self, samples):
        """Zero-pad the variable-length pdb/seq embeddings along dim 0 so they can be batched."""
        max_len_pdb_dim0 = max(s["pdb_encoder_out"].shape[0] for s in samples)
        max_len_seq_dim0 = max(s["seq_encoder_out"].shape[0] for s in samples)

        for sample in samples:
            pdb_embeddings = sample["pdb_encoder_out"]
            pad_pdb = ((0, max_len_pdb_dim0 - pdb_embeddings.shape[0]), (0, 0), (0, 0))
            sample["pdb_encoder_out"] = torch.tensor(np.pad(pdb_embeddings, pad_pdb, mode='constant'))

            seq_embeddings = sample["seq_encoder_out"]
            pad_seq = ((0, max_len_seq_dim0 - seq_embeddings.shape[0]), (0, 0), (0, 0))
            sample["seq_encoder_out"] = torch.tensor(np.pad(seq_embeddings, pad_seq, mode='constant'))

        return default_collate(samples)
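
# Minimal usage sketch (assumed, not part of the original training pipeline): wires
# ESMDataset into a DataLoader through its collater. The paths and batch_size below
# are placeholder assumptions.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = ESMDataset(
        pdb_root="xyz/pdb/",                          # <pdb_id>.pt structure embeddings
        seq_root="xyz/seq/",                          # <pdb_id>.pt sequence embeddings
        ann_paths="xyz/annotations.json",             # [{"pdb_id": ..., "caption": ...}, ...]
        dataset_description="xyz/description.json",   # {"train": [...], ...}
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collater)

    for batch in loader:
        # embeddings are zero-padded along dim 0 to the longest item in the batch
        print(batch["pdb_id"], batch["pdb_encoder_out"].shape, batch["seq_encoder_out"].shape)
        break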

# import os
# import sys

# import torch
# from torch.utils.data import Dataset
# import json
# import numpy as np
# from torch.utils.data.dataloader import default_collate

# import time

# class ESMDataset(Dataset):
#     def __init__(self, pdb_root, ann_paths, chain="A"):
#         """
#         protein (string): Root directory of protein (e.g. coco/images/)
#         ann_root (string): directory to store the annotation file
#         """
#         self.pdb_root = pdb_root
#         self.annotation = json.load(open(ann_paths, "r"))
#         self.pdb_ids = {}
#         self.chain = chain

#     def __len__(self):
#         return len(self.annotation)

#     def __getitem__(self, index):

#         ann = self.annotation[index]

#         protein_embedding = '{}.pt'.format(ann["pdb_id"])

#         protein_embedding_path = os.path.join(self.pdb_root, protein_embedding)
#         protein_embedding = torch.load(protein_embedding_path, map_location=torch.device('cpu'))
#         protein_embedding.requires_grad = False
#         caption = ann["caption"]

#         return {
#             "text_input": caption,
#             "encoder_out": protein_embedding,
#             "chain": self.chain,
#             "pdb_id": ann["pdb_id"]
#         }

#     def collater(self, samples):
#         max_len_protein_dim0 = -1
#         for pdb_json in samples:
#             pdb_embeddings = pdb_json["encoder_out"]
#             shape_dim0 = pdb_embeddings.shape[0]
#             max_len_protein_dim0 = max(max_len_protein_dim0, shape_dim0)
#         for pdb_json in samples:
#             pdb_embeddings = pdb_json["encoder_out"]
#             shape_dim0 = pdb_embeddings.shape[0]
#             pad1 = ((0, max_len_protein_dim0 - shape_dim0), (0, 0), (0, 0))
#             arr1_padded = np.pad(pdb_embeddings, pad1, mode='constant', )
#             pdb_json["encoder_out"] = arr1_padded

#         return default_collate(samples)