"""Inpainting pipeline: segment objects in images, then inpaint the masked regions with a diffusion model."""
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from torch.utils.data import DataLoader
from torchvision.transforms.functional import to_pil_image

from ..data import ImageFolderDataset
from ..models import create_diffusion_model, create_segmentation_model
# NOTE(review): `get_object_mask` is called in InpaintPipeline.run but was never
# imported (NameError at runtime) -- confirm it is exported by ..utils.
from ..utils import get_masked_images, get_object_mask
class InpaintPipeline:
    """Segment objects in images and inpaint them with a diffusion model.

    Pipeline per batch: load images from a folder -> predict semantic maps
    with a segmentation model -> build per-image binary object masks ->
    run a (ControlNet-conditioned) Stable Diffusion inpainting model over
    the masked regions.
    """

    def __init__(
        self,
        segmentation_model_name: str,
        diffusion_model_name: str,
        control_model_name: str,
        images_root: str,
        prompts_path: Optional[str] = None,
        sd_model_name: Optional[str] = "runwayml/stable-diffusion-inpainting",
        image_size: Optional[Tuple[int, int]] = (512, 512),
        image_extensions: Optional[Tuple[str, ...]] = (".jpg", ".jpeg", ".png", ".webp"),
        segmentation_model_size: Optional[str] = "large",
    ):
        """Build the segmentation model, the diffusion model, and the data loader.

        Args:
            segmentation_model_name: Identifier passed to ``create_segmentation_model``.
            diffusion_model_name: Identifier passed to ``create_diffusion_model``.
            control_model_name: ControlNet identifier forwarded to the diffusion factory.
            images_root: Root directory scanned for input images.
            prompts_path: Optional path to a prompts file consumed by the dataset.
            sd_model_name: Base Stable Diffusion inpainting checkpoint.
            image_size: (height, width) the dataset resizes images to.
            image_extensions: File extensions accepted when scanning ``images_root``.
            segmentation_model_size: Size variant of the segmentation model.
        """
        self.segmentation_model = create_segmentation_model(
            segmentation_model_name=segmentation_model_name,
            model_size=segmentation_model_size,
        )
        self.diffusion_model = create_diffusion_model(
            diffusion_model_name=diffusion_model_name,
            control_model_name=control_model_name,
            sd_model_name=sd_model_name,
        )
        self.data_loader = self.build_data_loader(
            images_root=images_root,
            prompts_path=prompts_path,
            image_size=image_size,
            image_extensions=image_extensions,
        )

    def build_data_loader(
        self,
        images_root: str,
        prompts_path: Optional[str] = None,
        image_size: Optional[Tuple[int, int]] = (512, 512),
        image_extensions: Optional[Tuple[str, ...]] = (".jpg", ".jpeg", ".png", ".webp"),
        batch_size: Optional[int] = 1,
        num_workers: Optional[int] = 8,
    ) -> DataLoader:
        """Create a non-shuffling DataLoader over an ``ImageFolderDataset``.

        Args:
            images_root: Root directory scanned for input images.
            prompts_path: Optional path to a prompts file.
            image_size: (height, width) target resize.
            image_extensions: Accepted file extensions.
            batch_size: Images per batch; defaults to 1, which is what
                ``run`` implicitly assumes when selecting prompts.
            num_workers: DataLoader worker processes (previously hard-coded to 8).

        Returns:
            A ``DataLoader`` yielding ``(images, prompts)`` batches in
            deterministic (unshuffled) order.
        """
        dataset = ImageFolderDataset(
            images_root, prompts_path, image_size, image_extensions
        )
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
        )

    def run(self, data_loader: Optional[DataLoader] = None) -> List[Dict[str, Any]]:
        """Run segmentation + inpainting over every batch and collect outputs.

        Args:
            data_loader: Optional replacement loader. NOTE: when provided it
                permanently replaces ``self.data_loader`` (kept for
                backward compatibility with existing callers).

        Returns:
            The concatenated ``output_images`` produced by the diffusion
            model across all batches.
            NOTE(review): the declared ``List[Dict[str, Any]]`` matches the
            original annotation, but the accumulated items are whatever
            ``outputs["output_images"]`` contains -- likely images; confirm.
        """
        if data_loader is not None:
            self.data_loader = data_loader

        results: List[Any] = []
        for images, prompts in self.data_loader:
            # Convert batch tensors to PIL images, the format both models consume.
            images = [to_pil_image(img) for img in images]
            semantic_maps = self.segmentation_model.process(images)
            # Mask out class 0 objects in each semantic map.
            # NOTE(review): `get_object_mask` is not imported in this module
            # (only `get_masked_images` is) -- confirm it comes from ..utils.
            object_masks = [
                get_object_mask(seg_map, class_id=0) for seg_map in semantic_maps
            ]
            outputs = self.diffusion_model.process(
                images=images,
                # Only the first prompt of the batch is used -- correct for the
                # default batch_size=1, silently drops prompts for larger batches.
                prompts=[prompts[0]],
                mask_images=object_masks,
                negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
            )
            results.extend(outputs["output_images"])
        return results