File size: 2,108 Bytes
a2dba58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import List, Tuple, TypeVar, Optional
import pandas as pd
import os
from PIL import Image
from torch import Tensor
from torchvision import transforms
import torch

# Type accepted for path-like constructor arguments: str, pathlib.Path or None.
# NOTE: this is a constrained TypeVar (not a Union alias); None is spelled as
# type(None), which is exactly what typing normalises a bare None constraint to.
PathLike = TypeVar("PathLike", str, Path, type(None))
# The Montgomery County CXR set this dataset is written for can be found at:
# https://data.lhncbc.nlm.nih.gov/public/Tuberculosis-Chest-X-ray-Datasets/Montgomery-County-CXR-Set/MontgomerySet/index.html


class MonDataset(Dataset):
    """Dataset yielding ``(scan, mask)`` tensor pairs listed in a CSV index.

    Each CSV row names a scan image in column ``"scan"`` and one mask image
    per entry of ``labels`` (by default the ``"right lung"`` and
    ``"left lung"`` columns). All file names are resolved relative to
    ``base_path``. Images and masks are converted to grayscale and resized
    to ``img_size`` x ``img_size``.

    Args:
        base_path: Directory the file names in the CSV are relative to.
        csv_path: Directory containing the CSV index file.
        csv_name: File name of the CSV index inside ``csv_path``.
        img_size: Side length every image and mask is resized to.
        labels: CSV column names whose mask files are merged into one mask.
        **kwargs: Accepted for call-site compatibility; ignored.
    """

    def __init__(self, base_path: PathLike,
                 csv_path: PathLike,
                 csv_name: str,
                 img_size: int = 128,
                 labels: List[str] = ('right lung', 'left lung', ),
                 **kwargs) -> None:
        super().__init__()
        self.df = pd.read_csv(os.path.join(csv_path, csv_name))
        self.base_path = Path(base_path)
        self.labels = labels
        self.img_size = img_size

    def _load_gray(self, fname: str) -> Tensor:
        """Load ``base_path / fname`` as a single-channel float tensor.

        The file is converted to 8-bit grayscale ('L'), resized to
        ``(img_size, img_size)`` and passed through ``ToTensor``, which
        yields values in [0, 1] with shape ``(1, img_size, img_size)``.
        """
        # The context manager releases the file handle deterministically;
        # convert() returns a fully loaded copy that outlives the handle.
        with Image.open(self.base_path / fname) as src:
            gray = src.convert('L').resize((self.img_size, self.img_size))
        return transforms.ToTensor()(gray).float()

    def load_image(self, fname: str) -> Tensor:
        """Return the scan ``fname`` as a ``(1, img_size, img_size)`` float tensor."""
        return self._load_gray(fname)

    def load_labels(self, fnames: List[str]) -> Tensor:
        """Load the mask files ``fnames`` and merge them into one binary mask.

        Each mask is binarised at 0.5 and the masks are summed. If any pixel
        is set in more than one mask, a message is printed and the sum is
        re-binarised.

        Returns:
            Float tensor of 0s and 1s with shape ``(1, img_size, img_size)``;
            all zeros when ``fnames`` is empty.
        """
        if not fnames:
            # torch.stack rejects an empty list; no masks means no foreground.
            return torch.zeros(1, self.img_size, self.img_size)
        masks = [(self._load_gray(fname) > .5).float() for fname in fnames]
        label = torch.stack(masks).sum(0)
        # The masks are expected to be disjoint (e.g. left vs. right lung).
        if (label > 1).any():
            print("overlapping lungs!", fnames)
            # Keep float dtype so every sample matches the non-overlap path.
            label = (label > .5).float()
        return label

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        """Return ``(scan, mask)`` for the CSV row at position ``index``."""
        # Positional access: with a non-unique DataFrame index, label-based
        # .loc lookups would return a Series instead of a single file name.
        row = self.df.iloc[index]
        img = self.load_image(row["scan"])
        labels = self.load_labels([row[col] for col in self.labels])
        return img, labels

    def __len__(self) -> int:
        """Number of rows in the CSV index."""
        return len(self.df)