import itertools
from typing import Any, List, Optional

import numpy as np
import torch
from controlnet_aux import HEDdetector, MLSDdetector, PidiNetDetector
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image

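# Maps a short task name to the controlnet_aux detector used to build the
# conditioning image and the ControlNet checkpoint that consumes it.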
MODEL_DICT = {
    "mlsd": {
        "name": "lllyasviel/Annotators",
        "detector": MLSDdetector,
        "model": "lllyasviel/control_v11p_sd15_mlsd",
    },
    "soft_edge": {
        "name": "lllyasviel/Annotators",
        "detector": PidiNetDetector,
        "model": "lllyasviel/control_v11p_sd15_softedge",
    },
    "hed": {
        "name": "lllyasviel/Annotators",
        "detector": HEDdetector,
        "model": "lllyasviel/sd-controlnet-hed",
    },
    "scribble": {
        "name": "lllyasviel/Annotators",
        "detector": HEDdetector,
        "model": "lllyasviel/control_v11p_sd15_scribble",
    },
}


class StableDiffusionControlNet:
    """ControlNet pipeline for generating images from text prompts and control images.

    Args:
        control_model_name (str):
            Key into ``MODEL_DICT`` selecting the ControlNet variant and its
            auxiliary detector (one of ``"mlsd"``, ``"soft_edge"``, ``"hed"``,
            ``"scribble"``).
        sd_model_name (str):
            Name of the Stable Diffusion base model on the Hugging Face Hub.
    """

    def __init__(
        self,
        control_model_name: str,
        sd_model_name: str = "runwayml/stable-diffusion-v1-5",
    ) -> None:
        # Auxiliary detector that turns an input image into a control image.
        self.processor = MODEL_DICT[control_model_name]["detector"].from_pretrained(
            MODEL_DICT[control_model_name]["name"]
        )
        self.pipe = self.create_pipe(
            sd_model_name=sd_model_name, control_model_name=control_model_name
        )

    def _repeat(self, items: List[Any], n: int) -> List[Any]:
        """Repeat each item of a list n times, keeping repeats adjacent.

        For example, ``_repeat(["a", "b"], 2)`` returns ``["a", "a", "b", "b"]``.

        Args:
            items (List[Any]): Items to be repeated.
            n (int): Number of repetitions per item.

        Returns:
            List[Any]: List with every item repeated ``n`` times.
        """
        return list(
            itertools.chain.from_iterable(itertools.repeat(item, n) for item in items)
        )

    def generate_control_images(self, images: List[Image.Image]) -> List[Image.Image]:
        """Generate control images from input images.

        Args:
            images (List[Image.Image]): Input images.

        Returns:
            List[Image.Image]: Control images.
        """
        return [self.processor(image) for image in images]

    def create_pipe(
        self, sd_model_name: str, control_model_name: str
    ) -> StableDiffusionControlNetPipeline:
        """Create a StableDiffusionControlNetPipeline.

        Args:
            sd_model_name (str): Stable Diffusion base model name.
            control_model_name (str): Key into ``MODEL_DICT`` for the ControlNet model.

        Returns:
            StableDiffusionControlNetPipeline
        """
        controlnet = ControlNetModel.from_pretrained(
            MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, torch_dtype=torch.float16
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        # Offload submodules to the CPU when idle to lower GPU memory usage, and
        # use xformers attention (requires the xformers package) for speed/memory.
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        return pipe

    def process(
        self,
        images: List[Image.Image],
        prompts: List[str],
        negative_prompt: Optional[str] = None,
        n_outputs: int = 1,
        num_inference_steps: int = 30,
    ) -> List[List[Image.Image]]:
        """Generate images from ``prompts``, conditioned on control images
        derived from ``images``.

        Args:
            images (List[Image.Image]): Input images, one per prompt.
            prompts (List[str]): List of prompts.
            negative_prompt (Optional[str], optional): Negative prompt applied to
                every generation. Defaults to None.
            n_outputs (int, optional): Number of outputs generated per prompt.
                Defaults to 1.
            num_inference_steps (int, optional): Number of denoising steps.
                Defaults to 30.

        Returns:
            List[List[Image.Image]]: For each input image, a list of ``n_outputs``
            generated images.
        """
        control_images = self.generate_control_images(images)
        assert len(prompts) == len(
            control_images
        ), "Number of prompts and input images must be equal."
        if n_outputs > 1:
            # Duplicate each prompt/control image n_outputs times so the pipeline
            # produces several variations per input.
            prompts = self._repeat(prompts, n=n_outputs)
            control_images = self._repeat(control_images, n=n_outputs)
        # One independently seeded generator per generated image; a wide seed
        # range avoids collisions that would yield duplicate outputs.
        generator = [
            torch.Generator(device="cuda").manual_seed(int(seed))
            for seed in np.random.randint(0, 2**31 - 1, size=len(prompts))
        ]
        output = self.pipe(
            prompts,
            image=control_images,
            negative_prompt=(
                [negative_prompt] * len(prompts) if negative_prompt is not None else None
            ),
            num_inference_steps=num_inference_steps,
            generator=generator,
        )
        # Regroup the flat list of generated images into one sub-list of
        # n_outputs images per original input image.
        output_images = [
            output.images[idx * n_outputs : (idx + 1) * n_outputs]
            for idx in range(len(images))
        ]
        return output_images
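

# Example usage (an illustrative sketch, not part of the original module):
# assumes a CUDA GPU, network access to download the checkpoints listed in
# MODEL_DICT, and a local "input.png"; file names and prompts are placeholders.
if __name__ == "__main__":
    controlnet = StableDiffusionControlNet(control_model_name="scribble")
    source = Image.open("input.png").convert("RGB")
    results = controlnet.process(
        images=[source],
        prompts=["a watercolor painting of a cottage in a forest"],
        negative_prompt="low quality, blurry",
        n_outputs=2,
        num_inference_steps=30,
    )
    # `results` holds one list per input image, each with n_outputs images.
    for idx, image in enumerate(results[0]):
        image.save(f"output_{idx}.png")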