from typing import Union

import torch
from diffusers import AutoPipelineForInpainting
from torchvision.transforms.functional import to_pil_image
from PIL import Image

# Inputs accepted by ``generate``; the diffusers inpainting pipelines take
# either PIL images or tensors for both the image and the mask.
ImageLike = Union[torch.Tensor, Image.Image]


class InpaintingModel:
    """Interface for models that repaint the masked region of an image from a text prompt."""

    def __init__(self) -> None:
        pass

    def generate(self, image: ImageLike, mask_image: ImageLike, prompt: str) -> Image.Image:
        """Return a new image in which the region selected by ``mask_image`` is repainted.

        Args:
            image: The source image.
            mask_image: Mask selecting the region to repaint.
            prompt: Text describing what should appear in the masked region.

        Returns:
            The generated image.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must override.
        """
        # The original stub silently returned None despite the Image return type.
        raise NotImplementedError(f"{type(self).__name__} must implement generate()")


class KandingskyInpaintingModel(InpaintingModel):
    """Inpainting backed by the Kandinsky 2.2 decoder-inpaint diffusers pipeline.

    The class name keeps its historical spelling for backward compatibility;
    ``KandinskyInpaintingModel`` is provided as a correctly spelled alias.
    """

    MODEL_ID = "kandinsky-community/kandinsky-2-2-decoder-inpaint"

    def __init__(
        self,
        device = torch.device("cpu"),
    ) -> None:
        """Download/load the pipeline and place it on ``device``.

        Args:
            device: Where to run inference (``torch.device`` or a string such as
                ``"cuda"``, ``"cuda:1"``, ``"cpu"``). On CUDA the weights are loaded
                in float16 and model CPU offload is enabled targeting that GPU; on
                any other device the weights are loaded in float32 and moved there.
                Pass a CUDA device explicitly to get GPU execution.
        """
        super().__init__()
        # Normalize so callers may pass either a string or a torch.device.
        self.device = torch.device(device)
        use_cuda = self.device.type == "cuda"

        # Half precision only on GPU: float16 kernels are poorly supported on CPU.
        dtype = torch.float16 if use_cuda else torch.float32
        self.model = AutoPipelineForInpainting.from_pretrained(self.MODEL_ID, torch_dtype=dtype)

        if use_cuda:
            # Offload idle sub-models to CPU and run each on the requested GPU.
            # ``index`` is None for a bare "cuda" device, which means GPU 0.
            self.model.enable_model_cpu_offload(gpu_id=self.device.index or 0)
        else:
            # Previously the requested device was ignored and offload always
            # targeted CUDA, which fails on CPU-only machines.
            self.model.to(self.device)

        self.negative_prompt = "deformed, ugly, disfigured"

    def generate(self, image: ImageLike, mask_image: ImageLike, prompt: str) -> Image.Image:
        """Repaint the masked region of ``image`` according to ``prompt``.

        A fixed negative prompt (``self.negative_prompt``) is always applied.

        Args:
            image: The source image.
            mask_image: Mask selecting the region to repaint; the pipeline's
                mask convention applies (TODO confirm which value marks the
                repaint region for this pipeline version).
            prompt: Text describing what should appear in the masked region.

        Returns:
            The first image produced by the pipeline.
        """
        output = self.model(
            prompt=prompt,
            negative_prompt=self.negative_prompt,
            image=image,
            mask_image=mask_image,
        )
        return output.images[0]


# Correctly spelled alias; the original (misspelled) name remains the canonical class.
KandinskyInpaintingModel = KandingskyInpaintingModel