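# Hugging Face file-viewer header captured together with the source (not part of the module):
# uploader: kingpreyansh | commit fe0ca90 "Added Stable Diffusion Files" | raw / history / blame
# scan: no virus | size: 10.2 kB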
from pathlib import Path
from typing import Optional, List, Callable, Dict, Any, Union
import warnings
import PIL.Image as pil_image
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder
from taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder
from taming.data.conditional_builder.utils import load_object_from_string
from taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType
from taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \
Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor
class AnnotatedObjectsDataset(Dataset):
    """Base dataset of images with object annotations and derived conditionings.

    This class stores the configuration, sets up the image transforms and
    validates the on-disk layout. It leaves ``annotations``,
    ``image_descriptions``, ``categories`` and ``image_ids`` as ``None``;
    they must be populated (presumably by subclasses) before items are read.
    Subclasses must implement :meth:`get_path_structure`,
    :meth:`get_image_description` and :meth:`get_image_path`.
    """

    def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int,
                 min_object_area: float, min_objects_per_image: int, max_objects_per_image: int,
                 crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool,
                 encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "",
                 no_object_classes: Optional[int] = None):
        """Store the configuration, build the transforms and resolve data paths.

        Args:
            data_path: Top-level directory; joined with the sub paths returned
                by :meth:`get_path_structure`.
            split: Dataset split identifier (only stored here).
            keys: Names of the sample entries ``__getitem__`` should return.
                An empty list returns the full sample.
            target_image_size: Edge length handed to the resize/crop transforms.
            min_object_area: Stored area threshold (see
                :meth:`filter_object_number`).
            min_objects_per_image: Stored lower bound on objects per image.
            max_objects_per_image: Stored upper bound on objects per image;
                also forwarded to the conditional builders.
            crop_method: One of ``'none'``, ``'center'``, ``'random-1d'``,
                ``'random-2d'`` or ``None``.
            random_flip: Whether to append a random horizontal flip.
            no_tokens: Forwarded to the conditional builders.
            use_group_parameter: Forwarded to the conditional builders.
            encode_crop: Forwarded to the conditional builders.
            category_allow_list_target: If non-empty, resolved with
                ``load_object_from_string``; the result is iterated as
                ``(name, _)`` pairs and the names form the allow list.
            category_mapping_target: If non-empty, resolved with
                ``load_object_from_string`` and used as the category mapping.
            no_object_classes: Optional override for :attr:`no_classes`.

        Raises:
            ValueError: If ``crop_method`` is not a supported value.
            FileNotFoundError: If an expected sub path does not exist.
            NotImplementedError: If :meth:`get_path_structure` is not overridden.
        """
        self.data_path = data_path
        self.split = split
        self.keys = keys
        self.target_image_size = target_image_size
        self.min_object_area = min_object_area
        self.min_objects_per_image = min_objects_per_image
        self.max_objects_per_image = max_objects_per_image
        self.crop_method = crop_method
        self.random_flip = random_flip
        self.no_tokens = no_tokens
        self.use_group_parameter = use_group_parameter
        self.encode_crop = encode_crop

        # Data containers; not filled in by this base class.
        self.annotations = None
        self.image_descriptions = None
        self.categories = None
        self.category_ids = None
        self.category_number = None
        self.image_ids = None

        self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip)
        self.paths = self.build_paths(self.data_path)
        self._conditional_builders = None

        self.category_allow_list = None
        if category_allow_list_target:
            allow_list = load_object_from_string(category_allow_list_target)
            self.category_allow_list = {name for name, _ in allow_list}

        # Always a dict (never None); empty means "no mapping configured".
        self.category_mapping = {}
        if category_mapping_target:
            self.category_mapping = load_object_from_string(category_mapping_target)

        self.no_object_classes = no_object_classes

    def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]:
        """Resolve the sub paths of the dataset and verify that each one exists.

        Args:
            top_level: Directory that all entries of :meth:`get_path_structure`
                are joined onto.

        Returns:
            Mapping from path name to the resolved ``Path``.

        Raises:
            FileNotFoundError: If any resolved path does not exist.
        """
        top_level = Path(top_level)
        sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()}
        for path in sub_paths.values():
            if not path.exists():
                raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.')
        return sub_paths

    @staticmethod
    def load_image_from_disk(path: Path) -> Image:
        """Open the image at ``path`` and return it converted to RGB.

        The source image is opened in a ``with`` block so the underlying file
        is released as soon as the converted copy has been produced.
        """
        with pil_image.open(path) as source:
            return source.convert('RGB')

    @staticmethod
    def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool):
        """Build the ordered list of image transforms for ``crop_method``.

        Args:
            target_image_size: Size handed to the resize and crop transforms.
            crop_method: ``'none'``, ``'center'``, ``'random-1d'``,
                ``'random-2d'`` or ``None``.
            random_flip: If true, a ``RandomHorizontalFlipReturn`` is appended
                after the resize/crop steps.

        Returns:
            The list of transforms, ending with a rescaling ``x / 127.5 - 1``,
            or ``None`` when ``crop_method`` is ``None`` (in which case
            :meth:`image_transform` cannot be used).

        Raises:
            ValueError: If ``crop_method`` is not one of the supported values.
        """
        transform_functions = []
        if crop_method == 'none':
            transform_functions.append(transforms.Resize((target_image_size, target_image_size)))
        elif crop_method == 'center':
            transform_functions.extend([
                transforms.Resize(target_image_size),
                CenterCropReturnCoordinates(target_image_size)
            ])
        elif crop_method == 'random-1d':
            transform_functions.extend([
                transforms.Resize(target_image_size),
                RandomCrop1dReturnCoordinates(target_image_size)
            ])
        elif crop_method == 'random-2d':
            # Unlike the other modes, the crop is taken before resizing.
            transform_functions.extend([
                Random2dCropReturnCoordinates(target_image_size),
                transforms.Resize(target_image_size)
            ])
        elif crop_method is None:
            return None
        else:
            raise ValueError(f'Received invalid crop method [{crop_method}].')
        if random_flip:
            transform_functions.append(RandomHorizontalFlipReturn())
        transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.))
        return transform_functions

    def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor):
        """Apply all configured transforms to ``x`` and report crop/flip results.

        Crop transforms and the flip transform return a pair whose first
        element is recorded and whose second element is the transformed image;
        every other transform just maps the image.

        Returns:
            ``(crop_bbox, flipped, image)``; ``crop_bbox`` and ``flipped`` stay
            ``None`` if no corresponding transform is configured.

        Raises:
            TypeError: If ``transform_functions`` is ``None`` (``crop_method``
                was ``None``), since there is nothing to iterate over.
        """
        crop_bbox = None
        flipped = None
        for t in self.transform_functions:
            if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)):
                crop_bbox, x = t(x)
            elif isinstance(t, RandomHorizontalFlipReturn):
                flipped, x = t(x)
            else:
                x = t(x)
        return crop_bbox, flipped, x

    @property
    def no_classes(self) -> int:
        """Number of object classes: the explicit override if set (and non-zero), else ``len(categories)``."""
        return self.no_object_classes or len(self.categories)

    @property
    def conditional_builders(self) -> Dict[str, Union[ObjectsCenterPointsConditionalBuilder,
                                                      ObjectsBoundingBoxConditionalBuilder]]:
        """Lazily created conditional builders, keyed by conditioning name.

        Built on first access rather than in ``__init__`` because
        ``no_classes`` may depend on ``self.categories``, which ``__init__``
        leaves as ``None``.
        """
        if self._conditional_builders is None:
            # Optional attribute that this class never sets itself; defaults to False.
            use_additional_parameters = getattr(self, 'use_additional_parameters', False)
            self._conditional_builders = {
                'objects_center_points': ObjectsCenterPointsConditionalBuilder(
                    self.no_classes,
                    self.max_objects_per_image,
                    self.no_tokens,
                    self.encode_crop,
                    self.use_group_parameter,
                    use_additional_parameters
                ),
                'objects_bbox': ObjectsBoundingBoxConditionalBuilder(
                    self.no_classes,
                    self.max_objects_per_image,
                    self.no_tokens,
                    self.encode_crop,
                    self.use_group_parameter,
                    use_additional_parameters
                )
            }
        return self._conditional_builders

    def filter_categories(self) -> None:
        """Restrict ``self.categories`` by the allow list and the category mapping.

        Keeps only categories whose ``name`` is in the allow list (if one is
        configured) and drops categories whose ``id`` is a key of the category
        mapping (if one is configured).
        """
        if self.category_allow_list:
            self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list}
        if self.category_mapping:
            self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping}

    def setup_category_id_and_number(self) -> None:
        """Derive the sorted ``category_ids`` and the id -> index ``category_number`` map.

        Warns if an allow list is configured without a category mapping and the
        number of resulting categories differs from the allow list size.
        """
        self.category_ids = sorted(self.categories)
        # This id is always moved to the end so it gets the highest index.
        # NOTE(review): presumably a special/catch-all class id — confirm.
        if '/m/01s55n' in self.category_ids:
            self.category_ids.remove('/m/01s55n')
            self.category_ids.append('/m/01s55n')
        self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)}
        # category_mapping is initialised to {} (never None), so "no mapping"
        # has to be tested by emptiness for this warning to be reachable.
        if self.category_allow_list is not None and not self.category_mapping \
                and len(self.category_ids) != len(self.category_allow_list):
            warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. '
                          'Make sure all names in category_allow_list exist.')

    def clean_up_annotations_and_image_descriptions(self) -> None:
        """Drop annotations and image descriptions whose key is not in ``image_ids``."""
        image_id_set = set(self.image_ids)
        self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set}
        self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set}

    @staticmethod
    def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float,
                             min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]:
        """Filter annotations by object area and images by remaining object count.

        Args:
            all_annotations: Mapping from image id to its annotations.
            min_object_area: Annotations with ``area`` not strictly greater
                than this value are discarded.
            min_objects_per_image: Inclusive lower bound on remaining objects.
            max_objects_per_image: Inclusive upper bound on remaining objects.

        Returns:
            Mapping containing only images whose remaining annotation count is
            within the bounds, each mapped to its remaining annotations.
        """
        filtered = {}
        for image_id, annotations in all_annotations.items():
            annotations_with_min_area = [a for a in annotations if a.area > min_object_area]
            if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image:
                filtered[image_id] = annotations_with_min_area
        return filtered

    def __len__(self):
        """Return the number of image ids."""
        return len(self.image_ids)

    def __getitem__(self, n: int) -> Dict[str, Any]:
        """Assemble the sample for the ``n``-th image id.

        The image is loaded and transformed only if ``'image'`` is requested in
        ``keys``. Each conditional builder whose name is in ``keys`` is then
        run on the annotations. If ``keys`` is non-empty, only those entries
        are returned.

        Raises:
            KeyError: If a requested key is not present in the sample.
        """
        image_id = self.get_image_id(n)
        sample = self.get_image_description(image_id)
        sample['annotations'] = self.get_annotation(image_id)

        if 'image' in self.keys:
            sample['image_path'] = str(self.get_image_path(image_id))
            sample['image'] = self.load_image_from_disk(sample['image_path'])
            sample['image'] = convert_pil_to_tensor(sample['image'])
            sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image'])
            sample['image'] = sample['image'].permute(1, 2, 0)

        for conditional, builder in self.conditional_builders.items():
            if conditional in self.keys:
                # crop_bbox / flipped only exist when the image was loaded above;
                # fall back to None so conditionals can be built without the image.
                sample[conditional] = builder.build(sample['annotations'], sample.get('crop_bbox'),
                                                    sample.get('flipped'))

        if self.keys:
            # only return specified keys
            sample = {key: sample[key] for key in self.keys}
        return sample

    def get_image_id(self, no: int) -> str:
        """Return the image id at position ``no``."""
        return self.image_ids[no]

    def get_annotation(self, image_id: str) -> List[Annotation]:
        """Return the annotations stored for ``image_id``."""
        return self.annotations[image_id]

    def get_textual_label_for_category_id(self, category_id: str) -> str:
        """Return the ``name`` of the category with the given id."""
        return self.categories[category_id].name

    def get_textual_label_for_category_no(self, category_no: int) -> str:
        """Return the ``name`` of the category with the given index."""
        return self.categories[self.get_category_id(category_no)].name

    def get_category_number(self, category_id: str) -> int:
        """Return the index assigned to ``category_id``."""
        return self.category_number[category_id]

    def get_category_id(self, category_no: int) -> str:
        """Return the category id at index ``category_no``."""
        return self.category_ids[category_no]

    def get_image_description(self, image_id: str) -> Dict[str, Any]:
        """Return the (mutable) base sample dict for ``image_id``; must be overridden."""
        raise NotImplementedError()

    def get_path_structure(self):
        """Return a mapping from path name to sub path below ``data_path``; must be overridden."""
        raise NotImplementedError

    def get_image_path(self, image_id: str) -> Path:
        """Return the on-disk path of the image for ``image_id``; must be overridden."""
        raise NotImplementedError