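"""ControlNet-guided Stable Diffusion image generation.

Wraps the `controlnet_aux` detectors (MLSD, PidiNet, HED) and a
`diffusers` StableDiffusionControlNetPipeline: input images are converted
to control images (line/edge maps) and used, together with text prompts,
to condition image generation.
"""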
import itertools
from typing import Any, List, Optional

import numpy as np
import torch
from PIL import Image

from controlnet_aux import MLSDdetector, PidiNetDetector, HEDdetector
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)


# Maps a short model key to its annotator ("detector") checkpoint and the
# matching ControlNet weights on the Hugging Face Hub.
MODEL_DICT = {
    "mlsd": {
        "name": "lllyasviel/Annotators",
        "detector": MLSDdetector,
        "model": "lllyasviel/control_v11p_sd15_mlsd",
    },
    "soft_edge": {
        "name": "lllyasviel/Annotators",
        "detector": PidiNetDetector,
        "model": "lllyasviel/control_v11p_sd15_softedge",
    },
    "hed": {
        "name": "lllyasviel/Annotators",
        "detector": HEDdetector,
        "model": "lllyasviel/sd-controlnet-hed",
    },
    "scribble": {
        "name": "lllyasviel/Annotators",
        "detector": HEDdetector,
        "model": "lllyasviel/control_v11p_sd15_scribble",
    },
}


class StableDiffusionControlNet:
    """ControlNet pipeline for generating images from prompts.

    Args:
        control_model_name (str):
            Key into ``MODEL_DICT`` selecting the ControlNet variant
            ("mlsd", "soft_edge", "hed", or "scribble").
        sd_model_name (str):
            Name of the Stable Diffusion base model on the Hugging Face Hub.
    """

    def __init__(
        self,
        control_model_name: str,
        sd_model_name: str = "runwayml/stable-diffusion-v1-5",
    ) -> None:
        self.processor = MODEL_DICT[control_model_name]["detector"].from_pretrained(
            MODEL_DICT[control_model_name]["name"]
        )
        self.pipe = self.create_pipe(
            sd_model_name=sd_model_name, control_model_name=control_model_name
        )

    def _repeat(self, items: List[Any], n: int) -> List[Any]:
        """Repeat each item of a list ``n`` times, keeping repeats adjacent.

        E.g. ``_repeat(["a", "b"], 2)`` returns ``["a", "a", "b", "b"]``,
        which matches the per-input chunking done in :meth:`process`.

        Args:
            items (List[Any]): Items to repeat.
            n (int): Number of repetitions per item.

        Returns:
            List[Any]: List with each item repeated ``n`` times.
        """
        return list(
            itertools.chain.from_iterable(itertools.repeat(item, n) for item in items)
        )

    def generate_control_images(self, images: List[Image.Image]) -> List[Image.Image]:
        """Generate control images from input images.

        Args:
            images (List[Image.Image]): Input images.

        Returns:
            List[Image.Image]: Control images.
        """
        return [self.processor(image) for image in images]

    def create_pipe(
        self, sd_model_name: str, control_model_name: str
    ) -> StableDiffusionControlNetPipeline:
        """Create a StableDiffusionControlNetPipeline.

        Args:
            sd_model_name (str): Stable Diffusion base model name.
            control_model_name (str): Key into ``MODEL_DICT`` selecting the
                ControlNet variant.

        Returns:
            StableDiffusionControlNetPipeline
        """
        controlnet = ControlNetModel.from_pretrained(
            MODEL_DICT[control_model_name]["model"], torch_dtype=torch.float16
        )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            sd_model_name, controlnet=controlnet, torch_dtype=torch.float16
        )

        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        # Offload submodules to CPU when idle to reduce peak GPU memory usage.
        pipe.enable_model_cpu_offload()
        # Memory-efficient attention; requires the `xformers` package.
        pipe.enable_xformers_memory_efficient_attention()

        return pipe

    def process(
        self,
        images: List[Image.Image],
        prompts: List[str],
        negative_prompt: Optional[str] = None,
        n_outputs: int = 1,
        num_inference_steps: int = 30,
    ) -> List[List[Image.Image]]:
        """Generate images from `prompts` using `control_images` and `negative_prompt`.

        Args:
            images (List[Image.Image]): Input images.
            prompts (List[str]): List of prompts.
            negative_prompt (Optional[str], optional): Negative prompt. Defaults to None.
            n_outputs (Optional[int], optional): Number of generated outputs. Defaults to 1.
            num_inference_steps (Optional[int], optional): Number of inference iterations. Defaults to 30.

        Returns:
            List[List[Image.Image]]
        """

        control_images = self.generate_control_images(images)

        if len(prompts) != len(control_images):
            raise ValueError("Number of prompts and input images must be equal.")

        if n_outputs > 1:
            prompts = self._repeat(prompts, n=n_outputs)
            control_images = self._repeat(control_images, n=n_outputs)

        # One independently seeded generator per prompt; seeds are drawn from
        # the full int32 range so repeated prompts still get distinct outputs.
        generator = [
            torch.Generator(device="cuda").manual_seed(int(seed))
            for seed in np.random.randint(np.iinfo(np.int32).max, size=len(prompts))
        ]

        output = self.pipe(
            prompts,
            image=control_images,
            negative_prompt=(
                [negative_prompt] * len(prompts) if negative_prompt is not None else None
            ),
            num_inference_steps=num_inference_steps,
            generator=generator,
        )

        # Regroup the flat output list into `n_outputs` images per original
        # input (repeats were kept adjacent by `_repeat`).
        output_images = [
            output.images[idx * n_outputs : (idx + 1) * n_outputs]
            for idx in range(len(images))
        ]

        return output_images
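

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original class. The file name
    # "input.png", the prompts, and `n_outputs=2` are placeholder assumptions;
    # a CUDA device and the pretrained weights are required.
    sd = StableDiffusionControlNet(control_model_name="scribble")
    source = Image.open("input.png").convert("RGB")
    results = sd.process(
        images=[source],
        prompts=["a cozy cabin in the woods, warm lighting, detailed"],
        negative_prompt="low quality, blurry",
        n_outputs=2,
    )
    for i, img in enumerate(results[0]):
        img.save(f"output_{i}.png")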