import torch import torch.nn as nn from torch.utils.data import DataLoader import torchvision.models as models import torchvision.transforms.functional as VF from torchvision import transforms import sys, argparse, os, glob import pandas as pd import numpy as np from PIL import Image from collections import OrderedDict class ToPIL(object): def __call__(self, sample): img = sample img = transforms.functional.to_pil_image(img) return img class BagDataset(): def __init__(self, csv_file, transform=None): self.files_list = csv_file self.transform = transform def __len__(self): return len(self.files_list) def __getitem__(self, idx): temp_path = self.files_list[idx] img = os.path.join(temp_path) img = Image.open(img) img = img.resize((224, 224)) sample = {'input': img} if self.transform: sample = self.transform(sample) return sample class ToTensor(object): def __call__(self, sample): img = sample['input'] img = VF.to_tensor(img) return {'input': img} class Compose(object): def __init__(self, transforms): self.transforms = transforms def __call__(self, img): for t in self.transforms: img = t(img) return img def save_coords(txt_file, csv_file_path): for path in csv_file_path: x, y = path.split('/')[-1].split('.')[0].split('_') txt_file.writelines(str(x) + '\t' + str(y) + '\n') txt_file.close() def adj_matrix(csv_file_path, output, device='cpu'): total = len(csv_file_path) adj_s = np.zeros((total, total)) for i in range(total-1): path_i = csv_file_path[i] x_i, y_i = path_i.split('/')[-1].split('.')[0].split('_') for j in range(i+1, total): # sptial path_j = csv_file_path[j] x_j, y_j = path_j.split('/')[-1].split('.')[0].split('_') if abs(int(x_i)-int(x_j)) <=1 and abs(int(y_i)-int(y_j)) <= 1: adj_s[i][j] = 1 adj_s[j][i] = 1 adj_s = torch.from_numpy(adj_s) adj_s = adj_s.to(device) return adj_s def bag_dataset(args, csv_file_path): transformed_dataset = BagDataset(csv_file=csv_file_path, transform=Compose([ ToTensor() ])) dataloader = DataLoader(transformed_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, drop_last=False) return dataloader, len(transformed_dataset)