"""This file contains methods for inference and image generation.""" | |
import logging | |
from typing import List, Tuple, Dict | |
import streamlit as st | |
import torch | |
import gc | |
import time | |
import numpy as np | |
from PIL import Image | |
from time import perf_counter | |
from contextlib import contextmanager | |
from scipy.signal import fftconvolve | |
from PIL import ImageFilter | |
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation | |
from diffusers import ControlNetModel, UniPCMultistepScheduler | |
from diffusers import StableDiffusionInpaintPipeline | |
from config import WIDTH, HEIGHT | |
from palette import ade_palette | |
from stable_diffusion_controlnet_inpaint_img2img import StableDiffusionControlNetInpaintImg2ImgPipeline | |
LOGGING = logging.getLogger(__name__) | |
def flush():
    """Free Python garbage and release cached GPU memory between generations."""
    gc.collect()
    torch.cuda.empty_cache()
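
# Both wrappers below serialize access to a single GPU pipeline with a simple
# FIFO ticket queue: every call takes an increasing ticket number, polls until
# its ticket reaches the head of the queue, runs the pipeline, then pops itself
# off and frees GPU memory. This keeps concurrent Streamlit sessions from
# running two generations on the GPU at once.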
class ControlNetPipeline:
    """Wrapper around the ControlNet inpainting pipeline with a FIFO request queue."""

    def __init__(self):
        self.controlnet = ControlNetModel.from_pretrained(
            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float16)

        self.pipe = StableDiffusionControlNetInpaintImg2ImgPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=self.controlnet,
            safety_checker=None,
            torch_dtype=torch.float16,
        )

        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")

        self.waiting_queue = []
        self.count = 0

    def queue_size(self) -> int:
        return len(self.waiting_queue)

    def __call__(self, **kwargs):
        self.count += 1
        number = self.count
        self.waiting_queue.append(number)

        # wait until this request's ticket is at the head of the queue
        while self.waiting_queue[0] != number:
            print(f"Wait for your turn {number} in queue {self.waiting_queue}")
            time.sleep(0.5)

        # it's this request's turn: run the pipeline, then leave the queue
        print("It's the turn of", number)
        results = self.pipe(**kwargs)
        self.waiting_queue.pop(0)
        flush()
        return results
class SDPipeline:
    """Wrapper around the Stable Diffusion inpainting pipeline with a FIFO request queue."""

    def __init__(self):
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
            safety_checker=None,
        )

        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")

        self.waiting_queue = []
        self.count = 0

    def queue_size(self) -> int:
        return len(self.waiting_queue)

    def __call__(self, **kwargs):
        self.count += 1
        number = self.count
        self.waiting_queue.append(number)

        # wait until this request's ticket is at the head of the queue
        while self.waiting_queue[0] != number:
            print(f"Wait for your turn {number} in queue {self.waiting_queue}")
            time.sleep(0.5)

        # it's this request's turn: run the pipeline, then leave the queue
        print("It's the turn of", number)
        results = self.pipe(**kwargs)
        self.waiting_queue.pop(0)
        flush()
        return results
def convolution(mask: Image.Image, size: int = 9) -> Image.Image:
    """Blur the mask with a normalized box filter so inpainted regions blend smoothly.

    Args:
        mask (Image.Image): mask image
        size (int, optional): side length of the box kernel. Defaults to 9.

    Returns:
        Image.Image: blurred mask
    """
    mask = np.array(mask.convert("L"))
    conv = np.ones((size, size)) / size**2
    mask_blended = fftconvolve(mask, conv, 'same')
    # clip before casting: FFT round-off can push values slightly outside [0, 255]
    mask_blended = np.clip(mask_blended, 0, 255).astype(np.uint8)

    # keep the original values on the borders, where the convolution is unreliable
    border = size
    mask_blended[:border, :] = mask[:border, :]
    mask_blended[-border:, :] = mask[-border:, :]
    mask_blended[:, :border] = mask[:, :border]
    mask_blended[:, -border:] = mask[:, -border:]

    return Image.fromarray(mask_blended).convert("L")
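
# A minimal sketch of how `convolution` behaves on a synthetic mask (the square
# below is a made-up example, not part of the app):
#
#   demo = Image.fromarray(np.zeros((128, 128), dtype=np.uint8))
#   demo.paste(255, (32, 32, 96, 96))     # white square on a black background
#   blurred = convolution(demo, size=9)   # hard edge becomes a 0..255 ramp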
def postprocess_image_masking(inpainted: Image.Image, image: Image.Image, mask: Image.Image) -> Image.Image:
    """Paste the inpainted pixels onto the original image using the (blurred) mask.

    Args:
        inpainted (Image.Image): inpainted image
        image (Image.Image): original image
        mask (Image.Image): mask selecting where the inpainted pixels are kept

    Returns:
        Image.Image: composited image
    """
    final_inpainted = Image.composite(inpainted.convert("RGBA"), image.convert("RGBA"), mask)
    return final_inpainted.convert("RGB")
@st.cache_resource
def get_controlnet() -> ControlNetPipeline:
    """Load the ControlNet pipeline wrapper, cached so the weights load once per process.

    Returns:
        ControlNetPipeline: controlnet pipeline wrapper
    """
    return ControlNetPipeline()
@st.cache_resource
def get_segmentation_pipeline() -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Load the segmentation processor and model, cached so the weights load once per process.

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: segmentation pipeline
    """
    image_processor = AutoImageProcessor.from_pretrained("openmmlab/upernet-convnext-small")
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
        "openmmlab/upernet-convnext-small")
    return image_processor, image_segmentor
@st.cache_resource
def get_inpainting_pipeline() -> SDPipeline:
    """Load the inpainting pipeline wrapper, cached so the weights load once per process.

    Returns:
        SDPipeline: inpainting pipeline wrapper
    """
    return SDPipeline()
def make_image_controlnet(image: np.ndarray,
                          mask_image: np.ndarray,
                          controlnet_conditioning_image: np.ndarray,
                          positive_prompt: str, negative_prompt: str,
                          seed: int = 2356132) -> Image.Image:
    """Generate an image with the ControlNet inpainting pipeline.

    Args:
        image (np.ndarray): input image
        mask_image (np.ndarray): mask image
        controlnet_conditioning_image (np.ndarray): conditioning image
        positive_prompt (str): positive prompt string
        negative_prompt (str): negative prompt string
        seed (int, optional): seed. Defaults to 2356132.

    Returns:
        Image.Image: generated image
    """
    pipe = get_controlnet()
    flush()

    image = Image.fromarray(image).convert("RGB")
    controlnet_conditioning_image = Image.fromarray(controlnet_conditioning_image).convert("RGB")
    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGB")
    mask_image_postproc = convolution(mask_image)

    st.success(f"{pipe.queue_size()} images in the queue, can take up to {(pipe.queue_size() + 1) * 10} seconds")

    generated_image = pipe(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=20,
        strength=1.00,
        guidance_scale=7.0,
        generator=[torch.Generator(device="cuda").manual_seed(seed)],
        image=image,
        mask_image=mask_image,
        controlnet_conditioning_image=controlnet_conditioning_image,
    ).images[0]

    return postprocess_image_masking(generated_image, image, mask_image_postproc)
def make_inpainting(positive_prompt: str,
                    image: Image.Image,
                    mask_image: np.ndarray,
                    negative_prompt: str = "") -> Image.Image:
    """Generate an image with the Stable Diffusion inpainting pipeline.

    Args:
        positive_prompt (str): positive prompt string
        image (Image.Image): input image
        mask_image (np.ndarray): mask image
        negative_prompt (str, optional): negative prompt string. Defaults to "".

    Returns:
        Image.Image: generated image
    """
    pipe = get_inpainting_pipeline()

    # convert the array mask to a PIL image before blurring it
    mask_image = Image.fromarray((mask_image * 255).astype(np.uint8))
    mask_image_postproc = convolution(mask_image)

    flush()
    st.success(f"{pipe.queue_size()} images in the queue, can take up to {(pipe.queue_size() + 1) * 10} seconds")

    generated_image = pipe(image=image,
                           mask_image=mask_image,
                           prompt=positive_prompt,
                           negative_prompt=negative_prompt,
                           num_inference_steps=20,
                           height=HEIGHT,
                           width=WIDTH,
                           ).images[0]

    return postprocess_image_masking(generated_image, image, mask_image_postproc)
def segment_image(image: Image.Image) -> Image.Image:
    """Segment the image with the UperNet model and colorize the predicted classes.

    Args:
        image (Image.Image): input image

    Returns:
        Image.Image: RGB segmentation map, colored with the ADE20K palette
    """
    image_processor, image_segmentor = get_segmentation_pipeline()
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    with torch.no_grad():
        outputs = image_segmentor(pixel_values)

    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]

    # map every class label to its palette color
    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    palette = np.array(ade_palette())
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color

    return Image.fromarray(color_seg).convert('RGB')
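
if __name__ == "__main__":
    # Minimal smoke test, assuming a CUDA GPU and a local photo "room.jpg"
    # (the filename and prompt are placeholders, not part of the app). The
    # `st.success` calls inside the helpers expect a Streamlit script context,
    # so outside `streamlit run` they may only log a warning.
    demo_image = Image.open("room.jpg").convert("RGB").resize((WIDTH, HEIGHT))
    segment_image(demo_image).save("room_segmentation.png")

    # inpaint the top-left quadrant of the photo
    demo_mask = np.zeros((HEIGHT, WIDTH), dtype=np.uint8)
    demo_mask[: HEIGHT // 2, : WIDTH // 2] = 1
    result = make_inpainting("a minimalist scandinavian living room", demo_image, demo_mask)
    result.save("room_inpainted.png")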