"""This file contains methods for inference and image generation."""
import gc
import logging
import threading
import time
from contextlib import contextmanager
from time import perf_counter
from typing import Dict, List, Tuple

import numpy as np
import streamlit as st
import torch
from PIL import Image
from PIL import ImageFilter
from scipy.signal import fftconvolve
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
from diffusers import ControlNetModel, UniPCMultistepScheduler
from diffusers import StableDiffusionInpaintPipeline

from config import WIDTH, HEIGHT
from palette import ade_palette
from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline

LOGGING = logging.getLogger(__name__)


def flush():
    """Run the garbage collector and release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()


class _QueuedPipeline:
    """Run a wrapped pipeline one request at a time, in arrival order.

    The pipeline wrappers in this module are created through
    ``st.experimental_singleton``, so a single instance is shared by all
    Streamlit sessions. Each call draws an increasing ticket number and
    polls until that ticket is at the head of ``waiting_queue``.

    Attributes:
        pipe: the wrapped callable pipeline.
        waiting_queue (List[int]): tickets of requests waiting or running.
        count (int): last ticket number handed out.
    """

    def __init__(self, pipe):
        self.pipe = pipe
        self.waiting_queue: List[int] = []
        self.count = 0
        # ``count += 1`` followed by a read is not atomic across threads, so
        # ticket assignment and queue mutation are done under this lock.
        self._lock = threading.Lock()

    @property
    def queue_size(self) -> int:
        """Number of requests currently waiting or running."""
        return len(self.waiting_queue)

    def __call__(self, **kwargs):
        """Wait for this request's turn, then call the wrapped pipeline.

        Args:
            **kwargs: forwarded unchanged to ``self.pipe``.

        Returns:
            Whatever ``self.pipe(**kwargs)`` returns.
        """
        with self._lock:
            self.count += 1
            number = self.count
            self.waiting_queue.append(number)
        try:
            # Poll until this ticket reaches the head of the queue.
            while self.waiting_queue[0] != number:
                LOGGING.debug("Wait for your turn %s in queue %s", number, self.waiting_queue)
                time.sleep(0.5)
            LOGGING.info("It's the turn of %s", number)
            return self.pipe(**kwargs)
        finally:
            # Always release this ticket, even if the pipeline raised: a ticket
            # left at the head would make every later request spin forever.
            # remove(number) rather than pop(0) so only our own ticket goes.
            with self._lock:
                self.waiting_queue.remove(number)
            flush()


class ControlNetPipeline(_QueuedPipeline):
    """Queued ControlNet inpaint/img2img pipeline, loaded in float16 on CUDA."""

    def __init__(self):
        # Not read anywhere in this file; kept so existing callers still find it.
        self.in_use = False
        self.controlnet = ControlNetModel.from_pretrained(
            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)
        pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_xformers_memory_efficient_attention()
        super().__init__(pipe.to("cuda"))


class SDPipeline(_QueuedPipeline):
    """Queued Stable Diffusion 2 inpainting pipeline, loaded in float16 on CUDA."""

    def __init__(self):
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        pipe.enable_xformers_memory_efficient_attention()
        super().__init__(pipe.to("cuda"))


def convolution(mask: Image.Image, size: int = 9) -> Image.Image:
    """Method to blur the mask with a ``size`` x ``size`` box filter.

    A band ``size`` pixels wide along every edge is copied unchanged from the
    input mask, so the image border is not blurred.

    Args:
        mask (Image.Image): masking image (converted to greyscale "L").
        size (int, optional): size of the blur kernel. Defaults to 9.

    Returns:
        Image.Image: blurred mask in mode "L".
    """
    mask_arr = np.array(mask.convert("L"))
    kernel = np.ones((size, size)) / size ** 2
    blurred = fftconvolve(mask_arr, kernel, mode="same")
    # FFT convolution leaves tiny floating-point error (e.g. 254.9999 or -1e-13).
    # Round and clamp before the uint8 cast so that fully masked regions stay
    # exactly 255 instead of being truncated to 254.
    blurred = np.clip(np.rint(blurred), 0, 255).astype(np.uint8)

    # replace borders with original values
    border = size
    blurred[:border, :] = mask_arr[:border, :]
    blurred[-border:, :] = mask_arr[-border:, :]
    blurred[:, :border] = mask_arr[:, :border]
    blurred[:, -border:] = mask_arr[:, -border:]

    return Image.fromarray(blurred).convert("L")


def postprocess_image_masking(inpainted: Image.Image, image: Image.Image,
                              mask: Image.Image) -> Image.Image:
    """Method to paste the inpainted image over the original using a mask.

    Args:
        inpainted (Image.Image): inpainted image.
        image (Image.Image): original image.
        mask (Image.Image): compositing mask; where it is white the inpainted
            pixels are used, where it is black the original pixels are kept.

    Returns:
        Image.Image: composited RGB image.
    """
    final_inpainted = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
    return final_inpainted.convert("RGB")


@st.experimental_singleton(max_entries=5)
def get_controlnet() -> ControlNetPipeline:
    """Method to load the queued controlnet pipeline (cached by Streamlit).

    Returns:
        ControlNetPipeline: controlnet pipeline wrapper.
    """
    return ControlNetPipeline()


@st.experimental_singleton(max_entries=5)
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Method to load the segmentation pipeline (cached by Streamlit).

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
    """
    image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
        "openmmlab/upernet-convnext-small")
    return image_processor, image_segmentor


@st.experimental_singleton(max_entries=5)
def get_inpainting_pipeline() -> SDPipeline:
    """Method to load the queued inpainting pipeline (cached by Streamlit).

    Returns:
        SDPipeline: inpainting pipeline wrapper.
    """
    return SDPipeline()


@torch.inference_mode()
def make_image_controlnet(image: np.ndarray,
                          mask_image: np.ndarray,
                          controlnet_conditioning_image: np.ndarray,
                          positive_prompt: str, negative_prompt: str,
                          seed: int = 2356132) -> Image.Image:
    """Method to make image using controlnet

    Args:
        image (np.ndarray): input image
        mask_image (np.ndarray): mask image; multiplied by 255 before the uint8
            cast, so presumably a 0/1 (or bool) mask — confirm with callers.
        controlnet_conditioning_image (np.ndarray): conditioning image
        positive_prompt (str): positive prompt string
        negative_prompt (str): negative prompt string
        seed (int, optional): seed. Defaults to 2356132.

    Returns:
        Image.Image: the generated image, composited over ``image`` using a
        blurred copy of the mask.
    """
    pipe = get_controlnet()
    flush()

    image = Image.fromarray(image).convert("RGB")
    controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB")
    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
    mask_image_postproc = convolution(mask_image)

    st.success(f"{pipe.queue_size} images in the queue, can take up to {(pipe.queue_size+1) * 10} seconds")
    generated_image = pipe(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=20,
        strength=1.00,
        guidance_scale=7.0,
        generator=[torch.Generator(device="cuda").manual_seed(seed)],
        image=image,
        mask_image=mask_image,
        controlnet_conditioning_image=controlnet_conditioning_image,
    ).images[0]
    return postprocess_image_masking(generated_image, image, mask_image_postproc)


@torch.inference_mode()
def make_inpainting(positive_prompt: str,
                    image: Image.Image,
                    mask_image: np.ndarray,
                    negative_prompt: str = "") -> Image.Image:
    """Method to make inpainting

    Args:
        positive_prompt (str): positive prompt string
        image (Image.Image): input image
        mask_image (np.ndarray): mask image; multiplied by 255 before the uint8
            cast, so presumably a 0/1 (or bool) mask — confirm with callers.
        negative_prompt (str, optional): negative prompt string. Defaults to "".

    Returns:
        Image.Image: the generated image, composited over ``image`` using a
        blurred copy of the mask.
    """
    pipe = get_inpainting_pipeline()
    # Build the PIL mask once. The pipeline needs it, and convolution() calls
    # .convert(), which a raw ndarray does not have.
    mask_pil = Image.fromarray((mask_image * 255).astype(np.uint8))
    mask_image_postproc = convolution(mask_pil)
    flush()

    st.success(f"{pipe.queue_size} images in the queue, can take up to {(pipe.queue_size+1) * 10} seconds")
    # NOTE(review): the original also splatted an undefined `common_parameters`
    # here (NameError); extra sampler kwargs such as guidance_scale or generator
    # may have been intended — confirm.
    generated_image = pipe(
        image=image,
        mask_image=mask_pil,
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=20,
        height=HEIGHT,
        width=WIDTH,
    ).images[0]
    return postprocess_image_masking(generated_image, image, mask_image_postproc)


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(image: Image.Image) -> Image.Image:
    """Method to segment image and colour each class with the ADE palette.

    Args:
        image (Image.Image): input image

    Returns:
        Image.Image: RGB segmentation map the same size as ``image``; pixels
        whose label has no palette entry stay black.
    """
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # inference_mode already disables autograd, so no extra no_grad() is needed.
    outputs = image_segmentor(pixel_values)

    # PIL size is (width, height); the post-processor expects (height, width).
    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    palette = np.array(ade_palette())
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color

    return Image.fromarray(color_seg).convert('RGB')