# TEDM-demo / dataloaders / Montgomery.py
# anonymous
# first commit without models
# a2dba58
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import List, Tuple, TypeVar, Optional
import pandas as pd
import os
from PIL import Image
from torch import Tensor
from torchvision import transforms
import torch
# Annotation helper for arguments that name a filesystem location.
PathLike = TypeVar("PathLike", str, Path, None)


# The Montgomery County chest X-ray set can be found at
# https://data.lhncbc.nlm.nih.gov/public/Tuberculosis-Chest-X-ray-Datasets/Montgomery-County-CXR-Set/MontgomerySet/index.html
class MonDataset(Dataset):
    """Dataset yielding ``(image, mask)`` pairs described by a CSV index.

    The CSV must contain a ``"scan"`` column plus one column per entry of
    ``labels``. Every cell holds a file name that is resolved relative to
    ``base_path``. Images and masks are converted to single-channel
    grayscale and resized to ``img_size x img_size``.
    """

    def __init__(self, base_path: PathLike,
                 csv_path: PathLike,
                 csv_name: str,
                 img_size: int = 128,
                 labels: List[str] = ('right lung', 'left lung', ),
                 **kwargs) -> None:
        """Read the CSV index and remember the loading configuration.

        Args:
            base_path: Directory that the file names in the CSV are
                relative to.
            csv_path: Directory containing the CSV index.
            csv_name: File name of the CSV index inside ``csv_path``.
            img_size: Edge length (in pixels) images and masks are resized to.
            labels: Names of the CSV columns holding the mask files that are
                merged into the target. The default is a tuple so that no
                mutable object is shared between instances.
            **kwargs: Accepted and ignored, so a shared configuration dict
                can be splatted into the constructor.
        """
        super().__init__()
        self.df = pd.read_csv(os.path.join(csv_path, csv_name))
        self.base_path = Path(base_path)
        self.labels = labels
        self.img_size = img_size

    def _load_gray(self, fname: str) -> Tensor:
        """Load ``fname`` as a resized grayscale float tensor."""
        # Image.open is lazy and keeps the file open until the pixel data is
        # read; the context manager guarantees the handle is released even if
        # decoding fails, instead of waiting for garbage collection.
        with Image.open(self.base_path / fname) as raw:
            gray = raw.convert('L').resize((self.img_size, self.img_size))
            return transforms.ToTensor()(gray).float()

    def load_image(self, fname: str) -> Tensor:
        """Load one scan.

        Args:
            fname: File name relative to ``base_path``.

        Returns:
            Float tensor produced by torchvision's ``ToTensor`` from the
            resized grayscale image (a single channel of
            ``img_size x img_size`` values scaled to ``[0, 1]``).
        """
        return self._load_gray(fname)

    def load_labels(self, fnames: List[str]) -> Tensor:
        """Load several mask files and merge them into one binary mask.

        Args:
            fnames: File names (relative to ``base_path``) of the individual
                masks. Must not be empty, since ``torch.stack`` rejects an
                empty list.

        Returns:
            Boolean tensor that is ``True`` wherever at least one of the
            masks is set.
        """
        # Binarise every mask before summing so each one contributes 0 or 1.
        masks = [(self._load_gray(fname) > .5).float() for fname in fnames]
        merged = torch.stack(masks).sum(0)
        # The lungs are not expected to overlap; a sum above 1 means two
        # masks claim the same pixel, which is reported but tolerated.
        if (merged > 1).any():
            print("overlapping lungs!", fnames)
        return merged > .5

    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        """Return the ``(image, mask)`` pair at positional ``index``."""
        # Translate the position into the frame's own index label so that
        # the lookup also works for frames with a non-default index.
        row_label = self.df.index[index]
        img = self.load_image(self.df.loc[row_label, "scan"])
        fnames = [self.df.loc[row_label, column] for column in self.labels]
        return img, self.load_labels(fnames)

    def __len__(self) -> int:
        """Return the number of rows in the CSV index."""
        return len(self.df)