|
from random import shuffle |
|
import torch |
|
import csv, os |
|
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Dataset, SequentialSampler |
|
from sklearn.model_selection import train_test_split |
|
from torchvision.io import read_image |
|
import torch.nn as nn |
|
from torchvision import transforms |
|
import pandas as pd |
|
import numpy as np |
|
from PIL import Image |
|
import math |
|
from transformers import AutoImageProcessor |
|
|
|
class imgDataset(Dataset):
    """Image dataset driven by a space-separated label file.

    Each line of the label file is split on single spaces (after its line
    terminator is removed). A line is kept only when one of its tokens equals
    ``mode``. From a kept line:

    * the image path is ``images/<field 0>/<field 2>/<field 1>``;
    * the class label is 0 when field 2 is ``"big"``, otherwise 1;
    * field 3 is parsed as a float (``h``);
    * field 1 is split on ``'-'`` and its first three segments are decoded by
      :meth:`convert_body_to_int` into ``b``, ``w`` and ``hh``.

    Args:
        path: Path to the label file (read as UTF-8).
        mode: ``'train'`` or ``'val'``. Selects both the lines to keep and the
            torchvision transform pipeline; any other value raises KeyError.
        use_processor: If True, images are preprocessed by the Hugging Face
            ``microsoft/resnet-50`` image processor instead of the torchvision
            pipeline.
    """

    _PROCESSOR_NAME = "microsoft/resnet-50"

    def __init__(self, path, mode='train', use_processor=True):
        self.path = path
        self.mode = mode
        self.use_processor = use_processor
        # Only load the pretrained processor when it will be used:
        # from_pretrained may download files or touch the HF cache.
        self.image_processor = (
            AutoImageProcessor.from_pretrained(self._PROCESSOR_NAME)
            if use_processor else None
        )
        # Per-channel mean/std normalisation shared by both pipelines
        # (the usual ImageNet statistics).
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        self.transform = {
            'train': transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]),
            'val': transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]),
        }
        # Raises KeyError for any mode other than 'train' / 'val'.
        self.trans = self.transform[mode]
        self.data = self.get_data()

    def convert_body_to_int(self, pos, file_name_list):
        """Decode the ``pos``-th ``'-'``-separated segment of field 1 to an int.

        An empty segment falls back to ``'62'``. A purely numeric segment is
        parsed as-is; otherwise the two characters following its first
        character (``segment[1:3]``) are parsed. Results ``<= 25`` are
        increased by 100.

        Raises:
            IndexError: if field 1 has fewer than ``pos + 1`` segments.
            ValueError: if the selected characters are not an integer.
        """
        body_str = file_name_list[1].split('-')[pos]
        if not body_str:
            # NOTE(review): fallback for a missing value; meaning of 62 is not documented here.
            body_str = '62'
        body = int(body_str) if body_str.isdigit() else int(body_str[1:3])
        # NOTE(review): presumably small values were stored without their hundreds digit — confirm.
        if body <= 25:
            body += 100
        return body

    def get_data(self):
        """Parse the label file into ``[img_path, label, h, b, w, hh]`` records for ``self.mode``."""
        data = []
        with open(self.path, 'r', encoding='utf-8') as f:
            for line in f:
                # Strip the terminator first; otherwise the last token carries
                # '\n' and can never compare equal to ``self.mode``.
                file_name_list = line.rstrip('\r\n').split(' ')
                if self.mode not in file_name_list:
                    continue
                label = 0 if file_name_list[2] == "big" else 1
                h = float(file_name_list[3])
                b = self.convert_body_to_int(0, file_name_list)
                w = self.convert_body_to_int(1, file_name_list)
                hh = self.convert_body_to_int(2, file_name_list)
                img_path = os.path.join('images', file_name_list[0], file_name_list[2], file_name_list[1])
                data.append([img_path, label, h, b, w, hh])
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label, h, b, w, hh)``; the four numeric targets are float tensors."""
        img_path, label, h, b, w, hh = self.data[idx]
        # Context manager releases the file handle promptly instead of
        # leaving it to garbage collection.
        with Image.open(img_path) as img:
            inp_img = img.convert("RGB")
        if self.use_processor:
            if self.image_processor is None:
                # use_processor was switched on after construction.
                self.image_processor = AutoImageProcessor.from_pretrained(self._PROCESSOR_NAME)
            # NOTE: the torchvision pipeline (including train-time augmentation)
            # is not applied on this path.
            # NOTE(review): presumably the returned pixel_values keep a leading
            # batch dim of 1 — confirm downstream code handles it.
            image_tensor = self.image_processor(images=inp_img, return_tensors="pt")
        else:
            image_tensor = self.trans(inp_img)
        return (
            image_tensor,
            label,
            torch.tensor(h, dtype=torch.float),
            torch.tensor(b, dtype=torch.float),
            torch.tensor(w, dtype=torch.float),
            torch.tensor(hh, dtype=torch.float),
        )
|
|
|
if __name__ == "__main__":
    # Smoke test: build both splits from the same label file, report their
    # sizes, then pull and print one shuffled training batch.
    label_file = 'labels.txt'
    datasets = {split: imgDataset(label_file, mode=split) for split in ('train', 'val')}
    loader = DataLoader(datasets['train'], batch_size=64, shuffle=True)
    print(len(datasets['train']), len(datasets['val']))
    first_batch = next(iter(loader))
    print(first_batch)