File size: 1,608 Bytes
a2dba58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from torch.utils.data import Dataset
from pathlib import Path
from typing import List, Tuple, TypeVar
import pandas as pd
import os
from PIL import Image
from torch import Tensor
from torchvision import transforms


# Annotation for the filesystem-location arguments of NIHDataset.__init__:
# a constrained TypeVar admitting ``str``, ``pathlib.Path`` or ``None``.
# NOTE(review): this is used as a plain alias rather than as a generic type
# parameter; a Union alias would be the conventional spelling — confirm
# before changing, since other modules may import this name.
PathLike = TypeVar("PathLike", str, Path, None)
# The NIH chest X-ray data can be found at https://www.kaggle.com/datasets/nih-chest-xrays/data

class NIHDataset(Dataset):
    """Dataset pairing grayscale scans with binary segmentation masks.

    The samples are listed in a CSV file that must provide a ``scan`` column
    and a ``mask`` column. Both columns hold image file names that are
    resolved relative to ``base_path``. Indexing the dataset yields an
    ``(image, mask)`` tuple of float tensors, each loaded as a single-channel
    image resized to ``img_size`` x ``img_size``.
    """

    def __init__(self, base_path: PathLike,
                 csv_path: PathLike,
                 csv_name: str,
                 img_size: int = 128,
                 labels: List[str] = ('right lung', 'left lung', ),
                 **kwargs) -> None:
        """Read the sample listing and remember the loading configuration.

        Args:
            base_path: Directory that the file names in the CSV are relative
                to.
            csv_path: Directory (or path prefix) containing the CSV file.
            csv_name: File name of the CSV listing the samples.
            img_size: Edge length, in pixels, that every scan and mask is
                resized to (the output is always square).
            labels: Names of the segmented structures. Stored on the instance
                as ``self.labels``; no method of this class reads it.
            **kwargs: Accepted for call-site compatibility and ignored.
        """
        super().__init__()
        # os.path.join (rather than Path) keeps URL-style prefixes intact,
        # which pandas.read_csv can also consume.
        self.df = pd.read_csv(os.path.join(csv_path, csv_name))
        self.base_path = Path(base_path)
        self.labels = labels
        self.img_size = img_size

    def _load_grayscale(self, fname: str) -> Tensor:
        """Load ``base_path / fname`` as a resized single-channel float tensor."""
        target_size = (self.img_size, self.img_size)
        # The context manager releases the file handle as soon as the pixel
        # data has been copied by convert(); the resized copy outlives it.
        with Image.open(self.base_path / fname) as raw:
            # resize() is called without a resample argument, so Pillow's
            # default filter applies.
            resized = raw.convert('L').resize(target_size)
        return transforms.ToTensor()(resized).float()

    def load_image(self, fname: str) -> Tensor:
        """Return the scan stored at ``base_path / fname`` as a float tensor.

        Args:
            fname: File name of the scan, relative to ``base_path``.

        Returns:
            The grayscale scan resized to ``img_size`` x ``img_size``.
            NOTE(review): ToTensor is expected to give a (1, H, W) tensor
            scaled to [0, 1] — confirm against the torchvision docs.
        """
        return self._load_grayscale(fname)

    def load_labels(self, fname: str) -> Tensor:
        """Return the mask stored at ``base_path / fname`` as a 0/1 tensor.

        Args:
            fname: File name of the mask, relative to ``base_path``.

        Returns:
            A float tensor holding 1.0 where the resized grayscale mask
            exceeds 0.5 and 0.0 elsewhere.
        """
        mask = self._load_grayscale(fname)
        # Binarize after resizing so interpolated edge values are snapped
        # back to {0, 1}.
        return (mask > .5).float()

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        """Return the ``(image, mask)`` pair at position ``index``.

        Rows are addressed by position, so the lookup yields exactly one
        file name per column even if ``self.df`` carries a non-contiguous or
        duplicated index (for example after filtering or concatenation).
        """
        scan_name = self.df["scan"].iloc[index]
        mask_name = self.df["mask"].iloc[index]

        img = self.load_image(scan_name)
        labels = self.load_labels(mask_name)

        return img, labels

    def __len__(self) -> int:
        """Return the number of rows in the sample listing."""
        return len(self.df)