|
import os |
|
import logging |
|
import warnings |
|
|
|
from medomni.common.registry import registry |
|
from medomni.datasets.builders.base_dataset_builder import BaseDatasetBuilder |
|
from medomni.datasets.datasets.laion_dataset import LaionDataset |
|
from medomni.datasets.datasets.cc_sbu_dataset import CCSBUDataset, CCSBUAlignDataset |
|
from medomni.datasets.datasets.med_dataset import MedDataset, MedAlignDataset |
|
from torch.utils.data import Dataset |
|
|
|
@registry.register_builder("cc_sbu")
class CCSBUBuilder(BaseDatasetBuilder):
    """Dataset builder registered as ``"cc_sbu"``.

    Produces only a ``"train"`` split, backed by ``CCSBUDataset`` and read
    from the location given by ``config.build_info.storage``. Its default
    config file is ``configs/datasets/cc_sbu/defaults.yaml``.
    """

    train_dataset_cls = CCSBUDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/cc_sbu/defaults.yaml"}

    def _download_ann(self):
        """Intentionally a no-op: this builder never downloads annotations."""

    def _download_vis(self):
        """Intentionally a no-op: this builder never downloads visual data."""

    def build(self):
        """Build the processors, then the training dataset.

        Returns:
            dict: ``{"train": ...}``, whose value is the ``inner_dataset``
            attribute of the constructed ``train_dataset_cls`` instance.
            (What ``inner_dataset`` is, e.g. a streaming pipeline, is
            defined by ``CCSBUDataset``; not visible here.)
        """
        self.build_processors()

        info = self.config.build_info
        train_split = "train"

        # The wrapper object is built only to expose its ``inner_dataset``.
        wrapper = self.train_dataset_cls(
            vis_processor=self.vis_processors[train_split],
            text_processor=self.text_processors[train_split],
            location=info.storage,
        )
        return {train_split: wrapper.inner_dataset}
|
|
|
|
|
@registry.register_builder("laion")
class LaionBuilder(BaseDatasetBuilder):
    """Dataset builder registered as ``"laion"``.

    Produces only a ``"train"`` split, backed by ``LaionDataset`` and read
    from ``config.build_info.storage``. Its default config file is
    ``configs/datasets/laion/defaults.yaml``.
    """

    train_dataset_cls = LaionDataset

    DATASET_CONFIG_DICT = {"default": "configs/datasets/laion/defaults.yaml"}

    def _download_ann(self):
        """Intentionally a no-op: this builder never downloads annotations."""

    def _download_vis(self):
        """Intentionally a no-op: this builder never downloads visual data."""

    def build(self):
        """Build the processors, then return the training split.

        Returns:
            dict: a single ``"train"`` entry holding the ``inner_dataset``
            attribute of a freshly constructed ``train_dataset_cls``.
        """
        self.build_processors()

        cfg_info = self.config.build_info
        key = "train"

        return {
            key: self.train_dataset_cls(
                vis_processor=self.vis_processors[key],
                text_processor=self.text_processors[key],
                location=cfg_info.storage,
            ).inner_dataset
        }
|
|
|
|
|
@registry.register_builder("cc_sbu_align")
class CCSBUAlignBuilder(BaseDatasetBuilder):
    """Dataset builder registered as ``"cc_sbu_align"``.

    Builds a ``"train"`` split of ``CCSBUAlignDataset``. Annotations are
    read from ``<storage>/filter_cap.json`` and images from
    ``<storage>/image``, where ``<storage>`` is ``config.build_info.storage``.
    """

    train_dataset_cls = CCSBUAlignDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/cc_sbu/align.yaml",
    }

    def build_datasets(self):
        """Build the processors and the training dataset.

        A missing storage directory only triggers a warning; construction
        proceeds regardless.

        Returns:
            dict: ``{'train': CCSBUAlignDataset(...)}``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        root = self.config.build_info.storage

        # Warn rather than fail, so that misconfiguration surfaces early
        # without aborting here.
        if not os.path.exists(root):
            warnings.warn("storage path {} does not exist.".format(root))

        train_set = self.train_dataset_cls(
            vis_processor=self.vis_processors["train"],
            text_processor=self.text_processors["train"],
            ann_paths=[os.path.join(root, 'filter_cap.json')],
            vis_root=os.path.join(root, 'image'),
        )
        return {'train': train_set}
|
|
|
@registry.register_builder("med")
class MedAlignBuilder(BaseDatasetBuilder):
    """Dataset builder registered as ``"med"``.

    Builds ``"train"`` and ``"eval"`` splits of ``MedAlignDataset``.
    Annotations come from ``<storage>/train.json`` and ``<storage>/val.json``,
    where ``<storage>`` is ``config.build_info.storage``. Images come from
    ``config.build_info.vis_root`` when set, otherwise from
    ``DEFAULT_VIS_ROOT``.
    """

    train_dataset_cls = MedAlignDataset

    DATASET_CONFIG_DICT = {
        "default": "configs/datasets/medinterp/align.yaml",
    }

    # Fallback image root, used only when the config does not provide
    # ``build_info.vis_root``. This is a machine-specific absolute path kept
    # for backward compatibility; set ``vis_root`` in the dataset config.
    DEFAULT_VIS_ROOT = "/home/zhouhy/physionet.org/files/mimic-cxr-jpg/2.0.0/files"

    # Output split name -> annotation file name inside the storage directory.
    SPLIT_ANN_FILES = {
        "train": "train.json",
        "eval": "val.json",
    }

    def build_datasets(self):
        """Build the processors and the train/eval datasets.

        Missing storage or image directories only trigger warnings;
        construction proceeds regardless.

        Returns:
            dict: ``{'train': MedAlignDataset(...), 'eval': MedAlignDataset(...)}``.
        """
        logging.info("Building datasets...")
        self.build_processors()

        build_info = self.config.build_info
        storage_path = build_info.storage

        # getattr with a default tolerates configs without a ``vis_root`` key.
        # This assumes missing keys raise AttributeError, as OmegaConf's
        # ConfigAttributeError does (TODO confirm). An explicit None/empty
        # value also falls back to the default.
        vis_root = getattr(build_info, "vis_root", None) or self.DEFAULT_VIS_ROOT

        if not os.path.exists(storage_path):
            warnings.warn("storage path {} does not exist.".format(storage_path))
        if not os.path.exists(vis_root):
            warnings.warn("image root {} does not exist.".format(vis_root))

        dataset_cls = self.train_dataset_cls

        # NOTE(review): the processors built above are not passed to
        # MedAlignDataset. Confirm the dataset performs its own
        # preprocessing.
        return {
            split: dataset_cls(
                ann_paths=[os.path.join(storage_path, ann_file)],
                vis_root=vis_root,
            )
            for split, ann_file in self.SPLIT_ANN_FILES.items()
        }