File size: 7,310 Bytes
89023a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# %%
# an example script of how to do outpainting with diffusers img2img pipeline
# should be compatible with any stable diffusion model
# (only tested with runwayml/stable-diffusion-v1-5)

from typing import Callable, List, Optional, Union
from PIL import Image
import PIL
import numpy as np
import torch

from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
)

# xformers attention only reduces memory use / speeds things up; it is not
# required for correctness. Fall back to the default attention when the
# xformers package (ImportError) or a CUDA device (ValueError) is unavailable
# instead of aborting the whole script.
try:
    pipe.enable_xformers_memory_efficient_attention()
except (ImportError, ValueError) as err:
    print(f"xformers memory-efficient attention disabled: {err}")
pipe.to("cuda")
# %%
# load the image, extract the mask
rgba = Image.open('primed_image_with_alpha_channel.png')
# alpha == 0 marks the transparent pixels that should be generated (outpainted).
# convert('RGBA') also handles palette / grey+alpha PNGs, which would otherwise
# not have a 4th channel to index.
mask_full = np.array(rgba.convert('RGBA'))[:, :, 3] == 0
rgb = rgba.convert('RGB')
# %%

# resize/convert the mask to the latent resolution that `outpaint` will
# actually work at. The size is taken from `preprocess` itself (the same
# function `outpaint` calls) instead of re-implementing its rounding, because
# that rounding differs between diffusers versions (multiples of 32 vs 8).
# for 512x512, the mask should be 1x4x64x64
vae_scale = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
latent_h, latent_w = (side // vae_scale for side in preprocess(rgb).shape[-2:])
mask_image = Image.fromarray(mask_full).resize((latent_w, latent_h), Image.NEAREST)
# True where the pixel was opaque, i.e. where the original latents are kept
mask = (np.array(mask_image) == 0)[None, None]
mask = np.repeat(mask, pipe.vae.config.latent_channels, axis=1)
mask = torch.from_numpy(mask).to('cuda')
mask.shape

# %%


@torch.no_grad()
def outpaint(
    self: StableDiffusionImg2ImgPipeline,
    prompt: Union[str, List[str]] = None,
    image: Union[torch.FloatTensor, PIL.Image.Image] = None,
    strength: float = 0.8,
    num_inference_steps: Optional[int] = 50,
    guidance_scale: Optional[float] = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: Optional[float] = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: Optional[int] = 1,
    keep_mask: Optional[torch.Tensor] = None,
    **kwargs,
):
    r"""
    Run img2img while pinning the known region of `image` to the original.

    copy of the original img2img pipeline's __call__()
    https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

    Changes are marked with <EDIT> and </EDIT>

    Extra argument compared to the img2img pipeline:
        keep_mask: boolean tensor at latent resolution, broadcastable to the
            latents (e.g. shape (1, 4, h, w) or (1, 1, h, w)); True where the
            original image's latents are kept and False where new content is
            generated. Defaults to the module-level `mask` built earlier in
            this script.

    Raises:
        ValueError: if `keep_mask`'s spatial size differs from the latents'.
    """
    # <EDIT> fall back to the mask prepared earlier in this script
    if keep_mask is None:
        keep_mask = mask
    # </EDIT>

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, strength, callback_steps,
                      negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]
    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 4. Preprocess image
    image = preprocess(image)

    # 5. set timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps, num_inference_steps = self.get_timesteps(
        num_inference_steps, strength, device)
    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)

    # 6. Prepare latent variables
    latents = self.prepare_latents(
        image, latent_timestep, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator
    )

    # <EDIT>
    # store the encoded version of the original image to overwrite
    # what the UNET generates "underneath" our image on each step
    encoded_original = (self.vae.config.scaling_factor *
                        self.vae.encode(
                            image.to(latents.device, latents.dtype)
                        ).latent_dist.mean)

    keep_mask = keep_mask.to(device=latents.device, dtype=torch.bool)
    if keep_mask.shape[-2:] != latents.shape[-2:]:
        raise ValueError(
            f"keep_mask has spatial size {tuple(keep_mask.shape[-2:])} but the "
            f"latents are {tuple(latents.shape[-2:])}; build the mask at the "
            "latent resolution of the preprocessed image")

    # One fixed noise sample used to re-noise the original on every step, as
    # diffusers' inpaint pipelines do. It is drawn on the generator's own
    # device (a CPU generator cannot feed a CUDA randn); with a list of
    # generators the first one is used.
    noise_generator = generator[0] if isinstance(generator, list) else generator
    noise = torch.randn(
        encoded_original.shape,
        generator=noise_generator,
        device=noise_generator.device if noise_generator is not None else device,
    ).to(device)
    # </EDIT>

    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 8. Denoising loop
    num_warmup_steps = len(timesteps) - \
        num_inference_steps * self.scheduler.order
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat(
                [latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t,
                                   encoder_hidden_states=prompt_embeds).sample

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * \
                    (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(
                noise_pred, t, latents, **extra_step_kwargs).prev_sample

            # <EDIT> paste unmasked regions from the original image
            # After the step, `latents` sits at the noise level of the NEXT
            # timestep, so the original is noised to that level; after the
            # final step the original latents are pasted without noise.
            if i < len(timesteps) - 1:
                known = self.scheduler.add_noise(
                    encoded_original, noise, timesteps[i + 1:i + 2])
            else:
                known = encoded_original
            # torch.where broadcasts the (1, C, h, w) mask/original over all
            # batch_size * num_images_per_prompt latents
            latents = torch.where(
                keep_mask, known.to(latents.device, latents.dtype), latents)
            # </EDIT>

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    # 9. Post-processing
    image = self.decode_latents(latents)

    # 10. Run safety checker
    image, has_nsfw_concept = self.run_safety_checker(
        image, device, prompt_embeds.dtype)

    # 11. Convert to PIL
    if output_type == "pil":
        image = self.numpy_to_pil(image)

    if not return_dict:
        return (image, has_nsfw_concept)

    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


# %%
# Run the masked img2img loop. The pipeline output holds a list of images;
# this single-string-prompt call produces exactly one, so keep the first.
image = outpaint(
    pipe,
    prompt="forest in the style of Tim Hildebrandt",
    image=rgb,
    guidance_scale=7.5,
    num_inference_steps=50,
    strength=0.5,
).images[0]
image  # display the result when run as a notebook cell

# %%
# The VAE encoding is lossy, so the kept region comes back slightly degraded.
# Pasting the original pixels over the known region of the result would restore
# full quality there, but may leave visible seams along the mask boundary.