import os |
import torch |
import torch.utils.data as data |
import numpy as np |
from PIL import Image, ImageFile |
import random |
from torchvision.transforms import ToTensor |
from torchvision import transforms |
import cv2 |
import pickle |
import torch.nn.functional as F |
def collate_features(batch):
    """Merge a batch of indexable items into one ``[images, coords]`` pair.

    Element 0 of every item is joined with ``torch.cat`` along dimension 0,
    and element 1 of every item is stacked row-wise with ``np.vstack``.

    Args:
        batch: Sequence of items where ``item[0]`` is a tensor and
            ``item[1]`` is array-like.

    Returns:
        A two-element list: the concatenated tensor followed by the stacked
        coordinate array.
    """
    # Gather and merge the tensors first, before touching the coordinates.
    tensors = []
    for entry in batch:
        tensors.append(entry[0])
    merged = torch.cat(tensors, dim=0)

    coord_rows = []
    for entry in batch:
        coord_rows.append(entry[1])
    stacked = np.vstack(coord_rows)

    return [merged, stacked]
def eval_transforms(pretrained=False):
    """Build the preprocessing pipeline used at evaluation time.

    Args:
        pretrained: When truthy, normalize with mean (0.485, 0.456, 0.406)
            and std (0.229, 0.224, 0.225) -- presumably the ImageNet
            statistics. Otherwise every channel uses 0.5 for both mean and
            std.

    Returns:
        A ``transforms.Compose`` applying ``Resize(224)``, ``ToTensor()`` and
        ``Normalize`` in that order.
    """
    if pretrained:
        channel_mean, channel_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
    else:
        channel_mean = channel_std = (0.5, 0.5, 0.5)

    steps = [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)
class GraphDataset(data.Dataset):
    """Dataset yielding pre-computed feature and adjacency tensors with labels.

    Every entry of ``ids`` is a tab-separated string whose first field is the
    sample name and whose second field is the label. For a sample named
    ``name`` the tensors are read from
    ``<root>/simclr_files/<name>/features.pt`` and
    ``<root>/simclr_files/<name>/adj_s.pt``.
    """

    def __init__(self, root, ids, metadata_path, target_patch_size=-1):
        """Store the sample list and load the label mapping.

        Args:
            root: Directory that contains the ``simclr_files`` folder.
            ids: Indexable sequence of ``name<TAB>label`` strings; newline
                characters in an entry are stripped before parsing.
            metadata_path: Directory containing ``label_map.pkl``. The
                unpickled object is indexed by label string to produce the
                value returned under ``'label'``.
            target_patch_size: Not used by this class; kept so existing
                callers keep working.
        """
        super(GraphDataset, self).__init__()
        self.root = root
        self.ids = ids
        # NOTE: unpickling can execute arbitrary code, so metadata_path must
        # only ever point at trusted files.
        label_map_path = os.path.join(metadata_path, 'label_map.pkl')
        # Context manager so the file handle is closed deterministically.
        with open(label_map_path, 'rb') as handle:
            self.classdict = pickle.load(handle)
        # Stored for compatibility; not read anywhere inside this class.
        self._up_kwargs = {'mode': 'bilinear'}

    @staticmethod
    def _load_or_default(path, make_default):
        """Load ``path`` with ``torch.load``, or report it missing and fall back.

        Args:
            path: File to load.
            make_default: Zero-argument callable producing the fallback
                value; only invoked when ``path`` does not exist.

        Returns:
            The loaded object, or ``make_default()`` if the file is absent.
        """
        if os.path.exists(path):
            # The identity map_location keeps storages where they were
            # deserialized instead of moving them to the saved device.
            return torch.load(path, map_location=lambda storage, loc: storage)
        print(path + ' not exists')
        return make_default()

    def __getitem__(self, index):
        """Return the sample at ``index``.

        Returns:
            A dict with keys ``'label'`` (looked up in the label map),
            ``'id'`` (sample name), ``'image'`` (loaded features) and
            ``'adj_s'`` (loaded adjacency). A missing features file is
            replaced by ``torch.zeros(1, 512)``; a missing adjacency file is
            replaced by an all-ones square matrix sized by the first
            dimension of the features.

        Raises:
            IndexError: If the entry has no tab-separated label field.
            KeyError: If the label is not present in the label map.
        """
        info = self.ids[index].replace('\n', '')
        fields = info.split('\t')
        file_name, label = fields[0], fields[1]

        sample = {}
        # Resolve the label before any disk access so a bad label fails fast.
        sample['label'] = self.classdict[label]
        sample['id'] = file_name

        sample_dir = os.path.join(self.root, 'simclr_files', file_name)
        features = self._load_or_default(
            os.path.join(sample_dir, 'features.pt'),
            lambda: torch.zeros(1, 512),
        )
        # The adjacency fallback depends on the (possibly fallback) features,
        # so it has to be resolved second.
        adj_s = self._load_or_default(
            os.path.join(sample_dir, 'adj_s.pt'),
            lambda: torch.ones(features.shape[0], features.shape[0]),
        )

        sample['image'] = features
        sample['adj_s'] = adj_s
        return sample

    def __len__(self):
        """Return the number of entries in ``ids``."""
        return len(self.ids)
''' def __getitem__(self, index): |
sample = {} |
info = self.ids[index].replace('\n', '') |
file_name, label = info.split('\t')[0].rsplit('.', 1)[0], info.split('\t')[1] |
site, file_name = file_name.split('/') |
# if site =='CCRCC': |
# file_path = self.root + 'CPTAC_CCRCC_features/simclr_files' |
if site =='LUAD' or site =='LSCC': |
site = 'LUNG' |
file_path = self.root + 'CPTAC_{}_features/simclr_files'.format(site) #_pre# with # rushin |
# For NLST only |
if site =='NLST': |
file_path = self.root + 'NLST_Lung_features/simclr_files' |
# For TCGA only |
if site =='TCGA': |
file_name = info.split('\t')[0] |
_, file_name = file_name.split('/') |
file_path = self.root + 'TCGA_LUNG_features/simclr_files' #_resnet_with |
sample['label'] = self.classdict[label] |
sample['id'] = file_name |
#feature_path = os.path.join(self.root, file_name, 'features.pt') |
feature_path = os.path.join(file_path, file_name, 'features.pt') |
if os.path.exists(feature_path): |
features = torch.load(feature_path, map_location=lambda storage, loc: storage) |
else: |
print(feature_path + ' not exists') |
features = torch.zeros(1, 512) |
#adj_s_path = os.path.join(self.root, file_name, 'adj_s.pt') |
adj_s_path = os.path.join(file_path, file_name, 'adj_s.pt') |
if os.path.exists(adj_s_path): |
adj_s = torch.load(adj_s_path, map_location=lambda storage, loc: storage) |
else: |
print(adj_s_path + ' not exists') |
adj_s = torch.ones(features.shape[0], features.shape[0]) |
#features = features.unsqueeze(0) |
sample['image'] = features |
sample['adj_s'] = adj_s #adj_s.to(torch.double) |
# return {'image': image.astype(np.float32), 'label': label.astype(np.int64)} |
return sample |
''' |