import json
import os

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
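
# Annotation schema (a sketch inferred from the fields accessed in __getitem__,
# not an authoritative spec): ann_paths should point to a JSON list of objects,
# each carrying at least "pdb_id" and "caption", e.g.
#   [{"pdb_id": "1abc", "caption": "..."}]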


class ESMDataset(Dataset):
    def __init__(self, pdb_root, ann_paths, chain="A"):
        """
        pdb_root (string): directory containing one precomputed ESM
            embedding per protein, stored as <pdb_id>.pt
        ann_paths (string): path to the annotation JSON file
        chain (string): chain identifier attached to every sample
        """
        self.pdb_root = pdb_root
        self.annotation = json.load(open(ann_paths, "r"))
        self.chain = chain

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]
        # Load the precomputed ESM embedding for this protein onto the CPU
        protein_embedding_path = os.path.join(
            self.pdb_root, "{}.pt".format(ann["pdb_id"])
        )
        protein_embedding = torch.load(protein_embedding_path, map_location="cpu")
        # The embeddings are fixed features, so never track gradients for them
        protein_embedding.requires_grad = False
        return {
            "text_input": ann["caption"],
            "encoder_out": protein_embedding,
            "chain": self.chain,
            "pdb_id": ann["pdb_id"],
        }

    def collater(self, samples):
        # Embeddings vary in length along dim 0, so zero-pad every tensor in
        # the batch up to the longest one before default_collate stacks them.
        max_len = max(sample["encoder_out"].shape[0] for sample in samples)
        for sample in samples:
            embedding = sample["encoder_out"]
            pad_rows = max_len - embedding.shape[0]
            if pad_rows > 0:
                # F.pad takes (left, right) pairs starting from the last dim;
                # this zero-pads only dim 0 of the 3-D embedding tensor.
                sample["encoder_out"] = F.pad(embedding, (0, 0, 0, 0, 0, pad_rows))
        return default_collate(samples)
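

# A minimal usage sketch (not part of the original file): wiring ESMDataset
# into a DataLoader with its custom collater. The paths below are hypothetical
# placeholders for wherever the embeddings and annotations actually live.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = ESMDataset(
        pdb_root="embeddings/",        # hypothetical dir of <pdb_id>.pt files
        ann_paths="annotations.json",  # hypothetical annotation JSON
    )
    loader = DataLoader(dataset, batch_size=4, collate_fn=dataset.collater)
    for batch in loader:
        # After padding, batch["encoder_out"] has shape (batch, max_len, ...)
        print(batch["pdb_id"], batch["encoder_out"].shape)
        break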