File size: 1,165 Bytes
fe70fd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
from torch.utils.data.dataloader import DataLoader,Dataset
import torch.optim as optim
import albumentations as A
from albumentations.pytorch import ToTensorV2

import numpy as np
import matplotlib.pyplot as plt
import os
from PIL import Image

class Segmentation_Dataset(Dataset):
    """Paired image/mask dataset for semantic segmentation.

    Every file in ``img_dir`` whose name ends with ``img_ext`` is treated as
    an input image. Its mask is looked up in ``mask_dir`` under the same name
    with the trailing ``img_ext`` swapped for ``mask_ext``
    (``cat.jpg`` -> ``cat.png`` with the defaults).

    Args:
        img_dir: Directory containing the input images.
        mask_dir: Directory containing the masks.
        transform: Optional callable invoked as
            ``transform(image=..., mask=...)`` that must return a mapping with
            ``"image"`` and ``"mask"`` keys (presumably an albumentations
            pipeline, given the file's imports).
        img_ext: Case-sensitive suffix identifying image files.
        mask_ext: Suffix used for the corresponding mask files.
    """

    def __init__(self, img_dir, mask_dir, transform=None, img_ext=".jpg", mask_ext=".png"):
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.img_ext = img_ext
        self.mask_ext = mask_ext
        # endswith (not a substring `in` test) so names like "a.jpg.bak" are
        # ignored; sorted because os.listdir order is arbitrary and indices
        # should be reproducible.
        self.images = sorted(
            name for name in os.listdir(img_dir) if name.endswith(img_ext)
        )

    def __len__(self):
        """Return the number of image files found in ``img_dir``."""
        return len(self.images)

    def _mask_filename(self, image_name):
        """Return the mask file name for *image_name* by swapping its suffix."""
        # Strip only the trailing img_ext; str.replace would also rewrite any
        # earlier occurrence inside the name (e.g. "x.jpg_v2.jpg").
        stem = image_name[: len(image_name) - len(self.img_ext)]
        return stem + self.mask_ext

    def __getitem__(self, idx):
        """Load sample *idx* and return ``(image, mask)``.

        Without a transform, ``image`` is the RGB pixel array of the image and
        ``mask`` is a single-channel ``float32`` array in which pixels equal to
        255 are set to 1.0. Any other values are left unchanged. With a
        transform, both are whatever the transform returns under the
        ``"image"`` / ``"mask"`` keys.
        """
        image_name = self.images[idx]
        img_path = os.path.join(self.img_dir, image_name)
        mask_path = os.path.join(self.mask_dir, self._mask_filename(image_name))

        # Context managers close the underlying files as soon as the pixels
        # are copied out, instead of whenever the image objects are collected.
        with Image.open(img_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as mask_img:
            mask = np.array(mask_img.convert("L"), dtype=np.float32)
        mask[mask == 255] = 1.0

        if self.transform is not None:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations["image"]
            mask = augmentations["mask"]

        return image, mask