File size: 22,392 Bytes
67e6974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4924c3
67e6974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4924c3
67e6974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4924c3
67e6974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4924c3
67e6974
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
from typing import Optional

import numpy as np
import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    SchedulerMixin,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from diffusers.models.attention_processor import AttnProcessor2_0
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from interpolation import (
    InnerInterpolatedAttnProcessor,
    OuterInterpolatedAttnProcessor,
    generate_beta_tensor,
    linear_interpolation,
    slerp,
    spherical_interpolation,
)


class InterpolationStableDiffusionPipeline:
    """
    Diffusion Pipeline that generates interpolated images
    """

    def __init__(
        self,
        repo_name: str = "CompVis/stable-diffusion-v1-4",
        scheduler_name: str = "ddim",
        frozen: bool = True,
        guidance_scale: float = 7.5,
        scheduler: Optional[SchedulerMixin] = None,
        cache_dir: Optional[str] = None,
    ):

        # Initialize the generator
        self.vae = AutoencoderKL.from_pretrained(
            repo_name, subfolder="vae", use_safetensors=True, cache_dir=cache_dir
        )
        self.tokenizer = CLIPTokenizer.from_pretrained(
            repo_name, subfolder="tokenizer", cache_dir=cache_dir
        )
        self.text_encoder = CLIPTextModel.from_pretrained(
            repo_name,
            subfolder="text_encoder",
            use_safetensors=True,
            cache_dir=cache_dir,
        )
        self.unet = UNet2DConditionModel.from_pretrained(
            repo_name, subfolder="unet", use_safetensors=True, cache_dir=cache_dir
        )

        # Initialize the scheduler
        if scheduler is not None:
            self.scheduler = scheduler
        elif scheduler_name == "ddim":
            self.scheduler = DDIMScheduler.from_pretrained(
                repo_name, subfolder="scheduler", cache_dir=cache_dir
            )
        elif scheduler_name == "unipc":
            self.scheduler = UniPCMultistepScheduler.from_pretrained(
                repo_name, subfolder="scheduler", cache_dir=cache_dir
            )
        else:
            raise ValueError(
                "Invalid scheduler name (ddim, unipc) and not specify scheduler."
            )

        # Setup device

        self.guidance_scale = guidance_scale  # Scale for classifier-free guidance

        if frozen:
            for param in self.unet.parameters():
                param.requires_grad = False

            for param in self.text_encoder.parameters():
                param.requires_grad = False

            for param in self.vae.parameters():
                param.requires_grad = False

    def to(self, *args, **kwargs):
        self.vae.to(*args, **kwargs)
        self.text_encoder.to(*args, **kwargs)
        self.unet.to(*args, **kwargs)

    def generate_latent(
        self, generator: Optional[torch.Generator] = None, torch_device: str = "cpu"
    ) -> torch.FloatTensor:
        """
        Generates a random latent tensor.

        Args:
            generator (Optional[torch.Generator], optional): Generator for random number generation. Defaults to None.
            torch_device (str, optional): Device to store the tensor. Defaults to "cpu".

        Returns:
            torch.FloatTensor: Random latent tensor.
        """
        channel = self.unet.config.in_channels
        height = self.unet.config.sample_size
        width = self.unet.config.sample_size
        if generator is None:
            latent = torch.randn(
                (1, channel, height, width),
                device=torch_device,
            )
        else:
            latent = torch.randn(
                (1, channel, height, width),
                generator=generator,
                device=torch_device,
            )
        return latent

    @torch.no_grad()
    def prompt_to_embedding(
        self, prompt: str, negative_prompt: str = ""
    ) -> torch.FloatTensor:
        """
        Prepare the text prompt for the diffusion process

        Args:
            prompt: str, text prompt
            negative_prompt: str, negative text prompt

        Returns:
            FloatTensor, text embeddings
        """

        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embeddings = self.text_encoder(text_input.input_ids.to(self.torch_device))[
            0
        ]

        uncond_input = self.tokenizer(
            negative_prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        uncond_embeddings = self.text_encoder(
            uncond_input.input_ids.to(self.torch_device)
        )[0]

        text_embeddings = torch.cat([text_embeddings, uncond_embeddings])
        return text_embeddings

    @torch.no_grad()
    def interpolate(
        self,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: Optional[str] = None,
        negative_prompt: str = "",
        size: int = 7,
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Interpolate between two generations, denoising all frames in one batch

        Args:
            latent_start: FloatTensor, latent vector of the first image
            latent_end: FloatTensor, latent vector of the second image
            prompt_start: str, text prompt of the first image
            prompt_end: str, text prompt of the second image
            guide_prompt: str, text prompt for the interpolation; when given, it
                replaces the interpolated prompt embedding of every in-between frame
            negative_prompt: str, negative text prompt
            size: int, number of interpolations including starting and ending points
            num_inference_steps: int, number of inference steps in scheduler
            warmup_ratio: float, ratio of warmup steps
            early: str, warmup interpolation methods ("pure_inner", "fused_inner",
                "pure_outer", "fused_outer" or "self")
            late: str, late interpolation methods (same choices as `early`)
            alpha: float, alpha parameter for beta distribution, defaults to num_inference_steps
            beta: float, beta parameter for beta distribution, defaults to num_inference_steps
            guidance_scale: Optional[float], scale for classifier-free guidance,
                defaults to the value given to the constructor
        Returns:
            Numpy array of interpolated images, shape (size, H, W, 3)
        Raises:
            KeyError: if the method name looked up for a step (`early` or `late`)
                is not one of the supported choices
        """
        # Remember the device the UNet lives on.
        self.torch_device = self.unet.device

        # Specify alpha and beta
        if alpha is None:
            alpha = num_inference_steps
        if beta is None:
            beta = num_inference_steps
        if guidance_scale is None:
            guidance_scale = self.guidance_scale
        self.scheduler.set_timesteps(num_inference_steps)

        # Prepare interpolated latents and embeddings
        latents = spherical_interpolation(latent_start, latent_end, size)
        embs_start = self.prompt_to_embedding(prompt_start, negative_prompt)
        emb_start = embs_start[0:1]
        uncond_emb_start = embs_start[1:2]
        embs_end = self.prompt_to_embedding(prompt_end, negative_prompt)
        emb_end = embs_end[0:1]
        uncond_emb_end = embs_end[1:2]

        # Perform prompt guidance if it is specified
        if guide_prompt is not None:
            guide_embs = self.prompt_to_embedding(guide_prompt, negative_prompt)
            guide_emb = guide_embs[0:1]
            uncond_guide_emb = guide_embs[1:2]
            embs = torch.cat([emb_start] + [guide_emb] * (size - 2) + [emb_end], dim=0)
            uncond_embs = torch.cat(
                [uncond_emb_start] + [uncond_guide_emb] * (size - 2) + [uncond_emb_end],
                dim=0,
            )
        else:
            embs = linear_interpolation(emb_start, emb_end, size=size)
            uncond_embs = linear_interpolation(
                uncond_emb_start, uncond_emb_end, size=size
            )

        # Specify the interpolation methods. The plain processor is created
        # once and reused for every unconditional pass.
        self_attn_proc = AttnProcessor2_0()
        procs_dict = {
            "pure_inner": InnerInterpolatedAttnProcessor(
                size=size, is_fused=False, alpha=alpha, beta=beta
            ),
            "fused_inner": InnerInterpolatedAttnProcessor(
                size=size, is_fused=True, alpha=alpha, beta=beta
            ),
            "pure_outer": OuterInterpolatedAttnProcessor(
                size=size, is_fused=False, alpha=alpha, beta=beta
            ),
            "fused_outer": OuterInterpolatedAttnProcessor(
                size=size, is_fused=True, alpha=alpha, beta=beta
            ),
            "self": self_attn_proc,
        }

        # Denoising process (the method decorator already disables gradients)
        warmup_step = int(num_inference_steps * warmup_ratio)
        for i, t in enumerate(tqdm(self.scheduler.timesteps), start=1):
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)

            # Change attention module. Steps are counted from 1, so `early` is
            # active for the first warmup_step - 1 steps.
            # NOTE(review): confirm whether `i <= warmup_step` was intended;
            # kept as-is so existing results stay reproducible.
            if i < warmup_step:
                interpolate_attn_proc = procs_dict[early]
            else:
                interpolate_attn_proc = procs_dict[late]
            self.unet.set_attn_processor(processor=interpolate_attn_proc)

            # Predict the conditional noise residual with the selected processor
            noise_pred = self.unet(
                latent_model_input, t, encoder_hidden_states=embs
            ).sample

            # The unconditional pass always uses the plain processor
            self.unet.set_attn_processor(processor=self_attn_proc)
            noise_uncond = self.unet(
                latent_model_input, t, encoder_hidden_states=uncond_embs
            ).sample

            # perform guidance
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the images. Use the scaling factor of the loaded VAE rather
        # than a constant; 0.18215 is kept as the fallback value.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images

    @torch.no_grad()
    def interpolate_save_gpu(
        self,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: Optional[str] = None,
        negative_prompt: str = "",
        size: int = 7,
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        alpha: Optional[float] = None,
        beta: Optional[float] = None,
        init: str = "linear",
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Interpolate between two generation

        Args:
            latent_start: FloatTensor, latent vector of the first image
            latent_end: FloatTensor, latent vector of the second image
            prompt_start: str, text prompt of the first image
            prompt_end: str, text prompt of the second image
            guide_prompt: str, text prompt for the interpolation
            negative_prompt: str, negative text prompt
            size: int, number of interpolations including starting and ending points
            num_inference_steps: int, number of inference steps in scheduler
            warmup_ratio: float, ratio of warmup steps
            early: str, warmup interpolation methods
            late: str, late interpolation methods
            alpha: float, alpha parameter for beta distribution
            beta: float, beta parameter for beta distribution
            init: str, interpolation initialization methods

        Returns:
            Numpy array of interpolated images, shape (size, H, W, 3)
        """
        self.torch_device = self.unet.device
        # Specify alpha and beta
        if alpha is None:
            alpha = num_inference_steps
        if beta is None:
            beta = num_inference_steps
        betas = generate_beta_tensor(size, alpha=alpha, beta=beta)
        final_images = None

        # Generate interpolated images one by one
        for i in range(size - 2):
            it = betas[i + 1].item()
            if init == "denoising":
                images = self.denoising_interpolate(
                    latent_start,
                    prompt_start,
                    prompt_end,
                    negative_prompt,
                    interpolated_ratio=it,
                    timesteps=num_inference_steps,
                )
            else:
                images = self.interpolate_single(
                    it,
                    latent_start,
                    latent_end,
                    prompt_start,
                    prompt_end,
                    guide_prompt=guide_prompt,
                    num_inference_steps=num_inference_steps,
                    warmup_ratio=warmup_ratio,
                    early=early,
                    late=late,
                    negative_prompt=negative_prompt,
                    init=init,
                    guidance_scale=guidance_scale,
                )
            if size == 3:
                return images
            if i == 0:
                final_images = images[:2]
            elif i == size - 3:
                final_images = np.concatenate([final_images, images[1:]], axis=0)
            else:
                final_images = np.concatenate([final_images, images[1:2]], axis=0)
        return final_images

    def interpolate_single(
        self,
        it: float,
        latent_start: torch.FloatTensor,
        latent_end: torch.FloatTensor,
        prompt_start: str,
        prompt_end: str,
        guide_prompt: Optional[str] = None,
        negative_prompt: str = "",
        num_inference_steps: int = 25,
        warmup_ratio: float = 0.5,
        early: str = "fused_outer",
        late: str = "self",
        init="linear",
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Generates the start image, one interpolated image and the end image.

        The three latents (start, interpolated, end) are denoised together as
        one batch.

        Args:
            it (float): Interpolation factor between latent_start and latent_end.
            latent_start (torch.FloatTensor): Starting latent vector.
            latent_end (torch.FloatTensor): Ending latent vector.
            prompt_start (str): Starting prompt for text conditioning.
            prompt_end (str): Ending prompt for text conditioning.
            guide_prompt (str, optional): Guiding prompt; when given, its embedding is used for the interpolated frame instead of a blend of the two prompt embeddings. Defaults to None.
            negative_prompt (str, optional): Negative prompt for text conditioning. Defaults to "".
            num_inference_steps (int, optional): Number of inference steps. Defaults to 25.
            warmup_ratio (float, optional): Ratio of warm-up steps. Defaults to 0.5.
            early (str, optional): Early attention processing method. Defaults to "fused_outer".
            late (str, optional): Late attention processing method. Defaults to "self".
            init (str, optional): Embedding blend method: "linear" uses torch.lerp, any other value uses slerp. Defaults to "linear".
            guidance_scale (Optional[float], optional): Scale for classifier-free guidance; the constructor value is used when None. Defaults to None.
        Returns:
            numpy.ndarray: uint8 images in channel-last layout, in the order start, interpolated, end.
        Raises:
            KeyError: If the method name looked up for a step (`early` or `late`) is not a supported choice.
        """
        # Remember the device the UNet lives on.
        self.torch_device = self.unet.device
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        # Prepare interpolated inputs
        self.scheduler.set_timesteps(num_inference_steps)

        embs_start = self.prompt_to_embedding(prompt_start, negative_prompt)
        emb_start = embs_start[0:1]
        uncond_emb_start = embs_start[1:2]
        embs_end = self.prompt_to_embedding(prompt_end, negative_prompt)
        emb_end = embs_end[0:1]
        uncond_emb_end = embs_end[1:2]

        def blend(first, second):
            # One place for the init-dependent choice of embedding blend.
            if init == "linear":
                return torch.lerp(first, second, it)
            return slerp(first, second, it)

        # Latents are always blended spherically, whatever `init` is.
        latent_t = slerp(latent_start, latent_end, it)
        if guide_prompt is not None:
            embs_guide = self.prompt_to_embedding(guide_prompt, negative_prompt)
            emb_t = embs_guide[0:1]
        else:
            emb_t = blend(emb_start, emb_end)
        uncond_emb_t = blend(uncond_emb_start, uncond_emb_end)

        latents = torch.cat([latent_start, latent_t, latent_end], dim=0)
        embs = torch.cat([emb_start, emb_t, emb_end], dim=0)
        uncond_embs = torch.cat([uncond_emb_start, uncond_emb_t, uncond_emb_end], dim=0)

        # Specify the attention processors. The plain processor is created
        # once and reused for every unconditional pass.
        self_attn_proc = AttnProcessor2_0()
        procs_dict = {
            "pure_inner": InnerInterpolatedAttnProcessor(t=it, is_fused=False),
            "fused_inner": InnerInterpolatedAttnProcessor(t=it, is_fused=True),
            "pure_outer": OuterInterpolatedAttnProcessor(t=it, is_fused=False),
            "fused_outer": OuterInterpolatedAttnProcessor(t=it, is_fused=True),
            "self": self_attn_proc,
        }

        warmup_step = int(num_inference_steps * warmup_ratio)
        for i, t in enumerate(tqdm(self.scheduler.timesteps), start=1):
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)
            with torch.no_grad():
                # Warmup. Steps are counted from 1, so `early` is active for
                # the first warmup_step - 1 steps.
                # NOTE(review): confirm whether `i <= warmup_step` was
                # intended; kept as-is so existing results stay reproducible.
                if i < warmup_step:
                    interpolate_attn_proc = procs_dict[early]
                else:
                    interpolate_attn_proc = procs_dict[late]
                self.unet.set_attn_processor(processor=interpolate_attn_proc)
                # predict the conditional noise residual
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=embs
                ).sample
                # the unconditional pass always uses the plain processor
                self.unet.set_attn_processor(processor=self_attn_proc)
                noise_uncond = self.unet(
                    latent_model_input, t, encoder_hidden_states=uncond_embs
                ).sample
            # perform guidance
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the images. Use the scaling factor of the loaded VAE rather
        # than a constant; 0.18215 is kept as the fallback value.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images

    def denoising_interpolate(
        self,
        latents: torch.FloatTensor,
        text_1: str,
        text_2: str,
        negative_prompt: str = "",
        interpolated_ratio: float = 1,
        timesteps: int = 25,
        guidance_scale: Optional[float] = None,
    ) -> np.ndarray:
        """
        Performs denoising interpolation on the given latents.

        The latents are denoised with the embedding of text_1 during the first
        part of the schedule and with the embedding of text_2 afterwards; the
        switch point is controlled by interpolated_ratio.

        Args:
            latents (torch.Tensor): The input latents.
            text_1 (str): The first text prompt.
            text_2 (str): The second text prompt.
            negative_prompt (str, optional): The negative text prompt. Defaults to "".
            interpolated_ratio (float, optional): Fraction of the schedule denoised with text_1 before switching to text_2. Defaults to 1.
            timesteps (int, optional): The number of timesteps for diffusion. Defaults to 25.
            guidance_scale (Optional[float], optional): Scale for classifier-free guidance; the constructor value is used when None. Defaults to None.

        Returns:
            numpy.ndarray: The interpolated images.
        """
        # Remember the device the UNet lives on, as the other entry points do,
        # so this method also works when it is called directly.
        self.torch_device = self.unet.device
        if guidance_scale is None:
            guidance_scale = self.guidance_scale

        self.unet.set_attn_processor(processor=AttnProcessor2_0())

        # Each call returns [prompt embedding, negative-prompt embedding], so
        # two calls provide all three embeddings that are needed.
        start_emb = self.prompt_to_embedding(text_1, negative_prompt)
        end_emb = self.prompt_to_embedding(text_2, negative_prompt)
        emb_1 = start_emb[0:1]
        emb_2 = end_emb[0:1]
        uncond_emb = start_emb[1:2]

        self.scheduler.set_timesteps(timesteps)
        for i, t in enumerate(tqdm(self.scheduler.timesteps), start=1):
            latent_model_input = self.scheduler.scale_model_input(latents, timestep=t)
            # predict the noise residual; the conditional and unconditional
            # branches are two separate forward passes
            with torch.no_grad():
                # Steps are counted from 1, so even with interpolated_ratio=1
                # the final step is conditioned on text_2.
                # NOTE(review): confirm whether that is intended; kept as-is
                # so existing results stay reproducible.
                if i < timesteps * interpolated_ratio:
                    cond_emb = emb_1
                else:
                    cond_emb = emb_2
                noise_pred = self.unet(
                    latent_model_input, t, encoder_hidden_states=cond_emb
                ).sample
                noise_uncond = self.unet(
                    latent_model_input, t, encoder_hidden_states=uncond_emb
                ).sample
            # perform guidance
            noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Decode the images. Use the scaling factor of the loaded VAE rather
        # than a constant; 0.18215 is kept as the fallback value.
        scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
        latents = 1 / scaling_factor * latents
        with torch.no_grad():
            image = self.vae.decode(latents).sample
        images = (image / 2 + 0.5).clamp(0, 1)
        images = (images.permute(0, 2, 3, 1) * 255).to(torch.uint8).cpu().numpy()
        return images