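"""Directed Diffusion sampling loop.

Runs Stable Diffusion with optional cross-attention editing: for the first
few denoising steps the selected prompt tokens are steered toward a
user-specified region of interest before sampling continues as usual.
"""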
import os
import torch
import random
import numpy as np
import datetime

from PIL import Image
from diffusers import LMSDiscreteScheduler
from tqdm.auto import tqdm
from torch import autocast
from difflib import SequenceMatcher

import DirectedDiffusion


@torch.no_grad()
def stablediffusion(
    model_bundle,
    attn_editor_bundle={},
    device="cuda",
    prompt="",
    steps=50,
    seed=None,
    width=512,
    height=512,
    t_start=0,
    guidance_scale=7.5,
    init_latents=None,
    is_save_attn=False,
    is_save_recons=False,
    folder = "./",
):
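    """Sample an image with Stable Diffusion, optionally applying Directed
    Diffusion cross-attention editing for the first few denoising steps.

    model_bundle must provide "unet", "vae", "clip_tokenizer" and
    "clip_text_model". attn_editor_bundle may provide "num_affected_steps",
    "edit_index", "roi", "noise_scale" and "num_trailing_attn"; when it is
    empty, the loop behaves like plain Stable Diffusion sampling. Returns the
    decoded image.
    """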

    # neural networks
    unet = model_bundle["unet"]
    vae = model_bundle["vae"]
    clip_tokenizer = model_bundle["clip_tokenizer"]
    clip = model_bundle["clip_text_model"]
    # attn editor bundle, our stuff
    num_affected_steps = int(attn_editor_bundle.get("num_affected_steps") or 0)
    if not num_affected_steps:
        print("Not using attn editor")
    else:
        print("Using attn editor")
    DirectedDiffusion.AttnCore.init_attention_edit(
        unet,
        tokens=attn_editor_bundle.get("edit_index") or [],
        rios=attn_editor_bundle.get("roi") or [],
        noise_scale=attn_editor_bundle.get("noise_scale") or [],
        length_prompt=len(prompt.split(" ")),
        num_trailing_attn=attn_editor_bundle.get("num_trailing_attn") or [],
    )

    # Change size to multiple of 64 to prevent size mismatches inside model
    width = width - width % 64
    height = height - height % 64
    # If no seed is given, pick one at random.
    if seed is None:
        seed = random.randrange(2 ** 32 - 1)
    # Explicit generator so the run is reproducible on the chosen device.
    generator = torch.Generator(device=device).manual_seed(seed)
    # Set inference timesteps to scheduler
    scheduler = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
    )
    scheduler.set_timesteps(steps)
    scheduler.timesteps = scheduler.timesteps.half().to(device)

    # Second scheduler, used only to weight the extra noise injected by the
    # attention editor during the first `num_affected_steps` steps.
    noise_weight = LMSDiscreteScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=10,
    )
    if num_affected_steps:
        noise_weight.set_timesteps(num_affected_steps)

    # Start from caller-supplied latents if given, otherwise from zeros.
    if init_latents is not None:
        init_latent = init_latents.to(device)
    else:
        init_latent = torch.zeros(
            (1, unet.in_channels, height // 8, width // 8), device=device
        )
    # Noise the initial latent up to the starting timestep.
    noise = torch.randn(init_latent.shape, generator=generator, device=device)
    latent = scheduler.add_noise(
        init_latent,
        noise,
        torch.tensor(
            [scheduler.timesteps[t_start]], device=device, dtype=torch.float16
        ),
    ).to(device)

    # Per-run output folder for optional attention-map / reconstruction dumps.
    current_time = datetime.datetime.now()
    current_time = current_time.strftime("%y%m%d-%H%M%S")
    folder = os.path.join(folder, current_time+"_internal")
    if not os.path.exists(folder) and (is_save_attn or is_save_recons):
        os.makedirs(folder)
    # Encode the unconditional (empty) prompt and the conditional prompt with CLIP.
    with autocast(device):
        embeds_uncond = DirectedDiffusion.AttnEditorUtils.get_embeds(
            "", clip, clip_tokenizer
        )
        embeds_cond = DirectedDiffusion.AttnEditorUtils.get_embeds(
            prompt, clip, clip_tokenizer
        )
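        # Denoising loop: classifier-free guidance at every step, with the
        # edited attention active only for the first num_affected_steps steps.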
        timesteps = scheduler.timesteps[t_start:]
        for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
            # Scale the latent to the sigma the UNet expects at this step.
            latent_model_input = scheduler.scale_model_input(latent, t).half()
            noise_pred_uncond = unet(
                latent_model_input, t, encoder_hidden_states=embeds_uncond
            ).sample

            if i < num_affected_steps:
                DirectedDiffusion.AttnEditorUtils.use_add_noise(
                    unet, noise_weight.timesteps[i]
                )
                DirectedDiffusion.AttnEditorUtils.use_edited_attention(unet)
                noise_pred_cond = unet(
                    latent_model_input, t, encoder_hidden_states=embeds_cond
                ).sample

            else:
                noise_pred_cond = unet(
                    latent_model_input, t, encoder_hidden_states=embeds_cond
                ).sample

            # Classifier-free guidance: move the prediction from the unconditional
            # output toward the prompt-conditioned one.
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_cond - noise_pred_uncond
            )
            latent = scheduler.step(noise_pred, t, latent).prev_sample

            if is_save_attn:
                filepath = os.path.join(folder, "ca.{:04d}.jpg".format(i))
                DirectedDiffusion.Plotter.plot_activation(filepath, unet, prompt, clip_tokenizer)
            if is_save_recons:
                filepath = os.path.join(folder, "recons.{:04d}.jpg".format(i))
                recons = DirectedDiffusion.AttnEditorUtils.get_image_from_latent(vae, latent)
                recons.save(filepath)
    return DirectedDiffusion.AttnEditorUtils.get_image_from_latent(vae, latent)
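

# Minimal usage sketch (not part of the original module). It assumes the
# standard CompVis/stable-diffusion-v1-4 checkpoint layout and that the
# attn_editor_bundle value formats below (token indices, ROI fractions)
# match what DirectedDiffusion.AttnCore.init_attention_edit expects.
if __name__ == "__main__":
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_id = "CompVis/stable-diffusion-v1-4"
    model_bundle = {
        "unet": UNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", torch_dtype=torch.float16
        ).to("cuda"),
        "vae": AutoencoderKL.from_pretrained(
            model_id, subfolder="vae", torch_dtype=torch.float16
        ).to("cuda"),
        "clip_tokenizer": CLIPTokenizer.from_pretrained(
            model_id, subfolder="tokenizer"
        ),
        "clip_text_model": CLIPTextModel.from_pretrained(
            model_id, subfolder="text_encoder", torch_dtype=torch.float16
        ).to("cuda"),
    }
    # Hypothetical editing setup: steer the token for "cat" toward the
    # top-left quarter of the image during the first 10 denoising steps.
    attn_editor_bundle = {
        "num_affected_steps": 10,
        "edit_index": [[2]],            # assumed CLIP token index of "cat"
        "roi": [(0.0, 0.5, 0.0, 0.5)],  # assumed (x0, x1, y0, y1) fractions
        "noise_scale": [1.0],
        "num_trailing_attn": [5],
    }
    image = stablediffusion(
        model_bundle,
        attn_editor_bundle=attn_editor_bundle,
        prompt="a cat sitting on a car",
        steps=50,
        seed=0,
    )
    image.save("directed_diffusion_example.png")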