from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import pandas as pd
import os
from PIL import Image
from torch import Tensor
from torchvision import transforms
import torch

PathLike = Union[str, Path]

PROJECT_DIR = Path(os.path.realpath(__file__)).parent.parent
DATADIR = Path("<PATH_TO_DATA>/JSRT")
# can be found at http://db.jsrt.or.jp/eng.php

def build_dataloaders(
        data_dir: PathLike = DATADIR,
        img_size: int = 128,
        batch_size: int = 16,
        num_workers: int = 1,
        n_labelled_images: Optional[int] = None,
        **kwargs
) -> Dict[str, DataLoader]:
    """
    Build train/val/test dataloaders for the JSRT dataset.

    If `n_labelled_images` is given, the training set is restricted to the
    first `n_labelled_images` samples.
    """
    train_ds = JSRTDataset(data_dir, PROJECT_DIR / "data", "JSRT_train_split.csv", img_size)
    if n_labelled_images is not None:
        train_ds = torch.utils.data.Subset(train_ds, range(n_labelled_images))
        print(f"Using {n_labelled_images} labelled images")
    val_ds = JSRTDataset(data_dir, PROJECT_DIR / "data", "JSRT_val_split.csv", img_size)
    test_ds = JSRTDataset(data_dir, PROJECT_DIR / "data", "JSRT_test_split.csv", img_size)

    dataloaders = {}
    dataloaders['train'] = DataLoader(train_ds, batch_size=batch_size,
                                      shuffle=True, num_workers=num_workers,
                                      pin_memory=True)
    dataloaders['val'] = DataLoader(val_ds, batch_size=batch_size,
                                    shuffle=False, num_workers=num_workers,
                                    pin_memory=True)
    dataloaders['test'] = DataLoader(test_ds, batch_size=batch_size,
                                     shuffle=False, num_workers=num_workers,
                                     pin_memory=True)

    return dataloaders
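
# Example usage (a minimal sketch; assumes DATADIR points at a local copy of the
# JSRT images with the SCR masks under DATADIR / "SCR/masks", and that the split
# CSVs exist under PROJECT_DIR / "data"):
#
#     dataloaders = build_dataloaders(img_size=128, batch_size=16)
#     images, masks = next(iter(dataloaders["train"]))
#     # images: (batch_size, 1, 128, 128), masks: (batch_size, 1, 128, 128)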


class JSRTDataset(Dataset):
    """Chest X-ray images with lung segmentation masks from the JSRT/SCR database."""

    def __init__(self, base_path: PathLike,
                 csv_path: PathLike,
                 csv_name: str,
                 img_size: int = 128,
                 labels: Tuple[str, ...] = ('right lung', 'left lung'),
                 **kwargs) -> None:
        self.df = pd.read_csv(os.path.join(csv_path, csv_name))
        self.base_path = Path(base_path)
        self.labels = labels
        self.img_size = img_size

    def load_image(self, fname: str) -> Tensor:
        """Load a grayscale image, resize it, and convert it to a (1, H, W) float tensor in [0, 1]."""
        img = Image.open(self.base_path / fname).convert('L').resize((self.img_size, self.img_size))
        img = transforms.ToTensor()(img).float()
        return img

    def load_labels(self, fnames: List[str]) -> Tensor:
        """Load the individual lung masks, binarise them, and merge them into one (1, H, W) mask."""
        labels = []
        for fname in fnames:
            label = Image.open(self.base_path / fname).convert('L').resize((self.img_size, self.img_size))
            # convert to a float tensor and binarise
            label = transforms.ToTensor()(label).float()
            label = (label > .5).float()
            labels.append(label)
        # merge the per-structure masks into a single mask
        label = torch.stack(labels).sum(0)
        # the left and right lung masks should not overlap; clamp back to binary if they do
        if (label > 1).sum() > 0:
            print("overlapping lungs!", fnames)
            label = (label > .5).float()
        return label

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        i = self.df.index[index]
        img = self.load_image(self.df.loc[i, "path"])

        # SCR mask layout: SCR/masks/<structure>/<image id>.gif
        label_paths = ["SCR/masks/" + item + "/" + self.df.loc[i, 'id'] + ".gif" for item in self.labels]
        labels = self.load_labels(label_paths)

        return img, labels

    def __len__(self):
        return len(self.df)
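

# A minimal smoke test (a sketch; DATADIR and the split CSVs are assumed to be in
# place, so this only runs against a local copy of the JSRT/SCR data).
if __name__ == "__main__":
    ds = JSRTDataset(DATADIR, PROJECT_DIR / "data", "JSRT_train_split.csv", img_size=128)
    img, mask = ds[0]
    # img and mask are (1, 128, 128) float tensors; the mask is binary
    print(len(ds), img.shape, mask.shape, mask.unique())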