# LASA / util / simple_image_loader.py
# Web file-view header (kept for provenance): HaolinLiu,
# commit 18bb538 "update files for demo", file size 2.95 kB.
import torch
import torch.nn as nn
from torch.utils import data
import os
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
import glob
def _convert_image_to_rgb(image):
    """Return *image* converted to 3-channel RGB mode.

    Defined at module level (not as a lambda) so the transform pipeline stays
    picklable when a DataLoader spawns worker processes.
    """
    return image.convert("RGB")


def image_transform(n_px):
    """Build the preprocessing pipeline for a PIL image.

    Steps, in order: force RGB, resize with bicubic interpolation (an integer
    size makes ``Resize`` scale the shorter side to ``n_px``), centre-crop to
    ``n_px`` x ``n_px``, convert to a float tensor, then normalise each of the
    three channels with a fixed mean and standard deviation.

    Args:
        n_px: Target side length, in pixels, of the square output image.

    Returns:
        A ``Compose`` callable mapping a PIL image to a normalised tensor.
    """
    return Compose([
        # Normalize below is hard-wired to three channels; without this step
        # RGBA, greyscale or palette images fail with a channel mismatch.
        _convert_image_to_rgb,
        Resize(n_px, interpolation=BICUBIC),
        CenterCrop(n_px),
        ToTensor(),
        # NOTE(review): these constants look like the CLIP image mean/std --
        # confirm against the encoder this loader feeds.
        Normalize((0.48145466, 0.4578275, 0.40821073),
                  (0.26862954, 0.26130258, 0.27577711)),
    ])
class Image_dataset(data.Dataset):
    """Dataset over every image under ``<dataset_folder>/other_data/<category>/6_images/<model>/``.

    Each item is a dict with the preprocessed image tensor plus the image
    name, model id and category recovered from the file path.
    """

    def __init__(self, dataset_folder="/data1/haolin/datasets", categories=("03001627",), n_px=224):
        """Collect the paths of all images belonging to ``categories``.

        Args:
            dataset_folder: Root folder that contains the ``other_data`` directory.
            categories: Iterable of category directory names; a single string
                is treated as one category rather than iterated per character.
            n_px: Side length passed to ``image_transform``.

        Raises:
            OSError: If a category's ``6_images`` directory, or one of its
                entries, cannot be listed.
        """
        if isinstance(categories, str):
            categories = [categories]
        self.dataset_folder = dataset_folder
        self.image_folder = os.path.join(self.dataset_folder, 'other_data')
        self.preprocess = image_transform(n_px)
        self.image_path = []
        for cat in categories:
            subpath = os.path.join(self.image_folder, cat, "6_images")
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping reproducible across runs and machines.
            for folder in sorted(os.listdir(subpath)):
                model_folder = os.path.join(subpath, folder)
                for image_filename in sorted(os.listdir(model_folder)):
                    self.image_path.append(os.path.join(model_folder, image_filename))

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.image_path)

    def __getitem__(self, index):
        """Load and preprocess the image at ``index``.

        Returns:
            dict with keys ``"images"`` (output of the preprocessing pipeline),
            ``"image_name"`` (file name without extension), ``"model_id"``
            (parent directory name) and ``"category"`` (directory three levels
            above the file).
        """
        path = self.image_path[index]
        # splitext copes with extensions of any length (e.g. ".jpeg"), unlike
        # slicing off a fixed four characters.
        basename = os.path.splitext(os.path.basename(path))[0]
        # Layout: .../<category>/6_images/<model_id>/<image file>
        parts = path.split(os.sep)
        model_id = parts[-2]
        category = parts[-4]
        # Image.open is lazy and keeps the file open; convert() forces the
        # read and returns an independent RGB image, so the handle can be
        # closed right away and the tensor always has three channels.
        with Image.open(path) as image_file:
            image = image_file.convert("RGB")
        image_tensor = self.preprocess(image)
        return {"images": image_tensor, "image_name": basename, "model_id": model_id, "category": category}
class Image_InTheWild_dataset(data.Dataset):
    """Dataset over ``<dataset_dir>/<scene>/6_images/<model>/*jpg`` images.

    Each item is a dict with the preprocessed image tensor plus the image
    name, model id and scene id recovered from the file path.
    """

    def __init__(self, dataset_dir="/data1/haolin/data/real_scene_process_data", scene_id="letian-310", n_px=224):
        """Collect image paths for one scene, or for every scene.

        Args:
            dataset_dir: Root folder holding one sub-directory per scene.
            scene_id: Name of the scene directory to load, or ``"all"`` to
                load every entry found in ``dataset_dir``.
            n_px: Side length passed to ``image_transform``.

        Raises:
            OSError: If ``scene_id`` is ``"all"`` and ``dataset_dir`` cannot
                be listed.
        """
        self.dataset_dir = dataset_dir
        self.preprocess = image_transform(n_px)
        self.image_path = []
        if scene_id == "all":
            # Sorted so the sample order does not depend on filesystem order.
            scene_names = sorted(os.listdir(self.dataset_dir))
        else:
            scene_names = [scene_id]
        for scene_name in scene_names:
            # Build the pattern with os.path.join so returned paths use
            # os.sep throughout, which __getitem__ relies on when splitting.
            pattern = os.path.join(self.dataset_dir, scene_name, "6_images", "*", "*jpg")
            self.image_path += sorted(glob.glob(pattern))

    def __len__(self):
        """Return the number of collected image files."""
        return len(self.image_path)

    def __getitem__(self, index):
        """Load and preprocess the image at ``index``.

        Returns:
            dict with keys ``"images"`` (output of the preprocessing pipeline),
            ``"image_name"`` (file name without extension), ``"model_id"``
            (parent directory name) and ``"scene_id"`` (directory three levels
            above the file).
        """
        path = self.image_path[index]
        # splitext copes with extensions of any length, unlike slicing off a
        # fixed four characters.
        basename = os.path.splitext(os.path.basename(path))[0]
        # Layout: .../<scene_id>/6_images/<model_id>/<image file>
        parts = path.split(os.sep)
        model_id = parts[-2]
        scene_id = parts[-4]
        # Image.open is lazy and keeps the file open; convert() forces the
        # read and returns an independent RGB image (greyscale or CMYK JPEGs
        # would otherwise break the 3-channel normalisation), so the handle
        # can be closed right away.
        with Image.open(path) as image_file:
            image = image_file.convert("RGB")
        image_tensor = self.preprocess(image)
        return {"images": image_tensor, "image_name": basename, "model_id": model_id, "scene_id": scene_id}