Ai_Portrait_Mode / dataset.py
Heisenberg08's picture
added model and code
fe70fd4
raw
history blame
No virus
1.17 kB
import torch
from torch.utils.data.dataloader import DataLoader,Dataset
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image
class Segmentation_Dataset(Dataset):
    """Dataset of paired RGB images and single-channel segmentation masks.

    Images are the files in ``img_dir`` whose name ends with ``.jpg``. The
    mask for an image is read from ``mask_dir`` under the same base name with
    a ``.png`` extension.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        """Index the ``.jpg`` files found in ``img_dir``.

        Args:
            img_dir: Directory containing the input ``.jpg`` images.
            mask_dir: Directory containing the matching ``.png`` masks.
            transform: Optional callable invoked as
                ``transform(image=image, mask=mask)`` and expected to return
                a mapping with ``"image"`` and ``"mask"`` keys (presumably an
                albumentations pipeline; confirm against the caller).
        """
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # Match on the suffix rather than a substring so names such as
        # "x.jpg.bak" are not indexed as images. os.listdir returns names in
        # arbitrary order, so sort to keep index -> sample stable across runs.
        self.images = sorted(
            name for name in os.listdir(img_dir) if name.endswith(".jpg")
        )

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the image/mask pair at position ``idx``.

        Args:
            idx: Index into the sorted list of image file names.

        Returns:
            A ``(image, mask)`` tuple. Without a transform, ``image`` is the
            RGB image as a NumPy array and ``mask`` is the grayscale mask as
            a ``float32`` NumPy array in which the value 255 has been
            replaced by 1.0. With a transform, both are whatever the
            transform returns under its ``"image"`` and ``"mask"`` keys.
        """
        file_name = self.images[idx]
        img_path = os.path.join(self.img_dir, file_name)
        # Swap only the extension; str.replace would also rewrite a ".jpg"
        # that happens to appear earlier in the file name.
        stem = os.path.splitext(file_name)[0]
        mask_path = os.path.join(self.mask_dir, stem + ".png")

        # Context managers release the underlying file handles promptly.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32)

        # Map the 255 label to 1.0; every other value is left unchanged.
        mask[mask == 255] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]
        return image, mask