from functools import partial

import torch
from diffusers import StableDiffusionXLKDiffusionPipeline
from k_diffusion.sampling import get_sigmas_polyexponential
from k_diffusion.sampling import sample_dpmpp_2m_sde


# Replace the scheduler's default set_timesteps with a polyexponential sigma
# schedule built from the scheduler's original sigma range.
def set_timesteps_polyexponential(self, orig_sigmas, num_inference_steps, device=None):
    self.num_inference_steps = num_inference_steps

    self.sigmas = get_sigmas_polyexponential(
        num_inference_steps + 1,
        sigma_min=orig_sigmas[-2],
        sigma_max=orig_sigmas[0],
        rho=0.666666,
        device=device or "cpu",
    )
    # Trim the last two entries and terminate the schedule with a single zero sigma.
    self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas.new_zeros([1])])


# Load the SDXL k-diffusion pipeline, patch its scheduler to use the
# polyexponential schedule above, and select DPM++ 2M SDE (Heun) as the sampler.
def load_model(model_id="KBlueLeaf/Kohaku-XL-Epsilon", device="cuda"):
    pipe: StableDiffusionXLKDiffusionPipeline
    pipe = StableDiffusionXLKDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=torch.float16
    ).to(device)
    pipe.scheduler.set_timesteps = partial(
        set_timesteps_polyexponential, pipe.scheduler, pipe.scheduler.sigmas
    )
    pipe.sampler = partial(sample_dpmpp_2m_sde, eta=0.35, solver_type="heun")
    return pipe


# Encode prompts longer than CLIP's 77-token window by tokenizing the full text,
# splitting the ids into chunks of model_max_length, and concatenating the chunk
# embeddings along the sequence dimension.
def encode_prompts(pipe: StableDiffusionXLKDiffusionPipeline, prompt, neg_prompt):
    max_length = pipe.tokenizer.model_max_length

    # Tokenize the positive prompt with both SDXL tokenizers, without truncation.
    input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")
    input_ids2 = pipe.tokenizer_2(prompt, return_tensors="pt").input_ids.to("cuda")

    # Pad the negative prompt to the positive prompt's token length.
    negative_ids = pipe.tokenizer(
        neg_prompt,
        truncation=False,
        padding="max_length",
        max_length=input_ids.shape[-1],
        return_tensors="pt",
    ).input_ids.to("cuda")
    negative_ids2 = pipe.tokenizer_2(
        neg_prompt,
        truncation=False,
        padding="max_length",
        max_length=input_ids.shape[-1],
        return_tensors="pt",
    ).input_ids.to("cuda")

    # If the negative prompt is longer, re-pad the positive prompt to match it instead.
    if negative_ids.shape[-1] > input_ids.shape[-1]:
        input_ids = pipe.tokenizer(
            prompt,
            truncation=False,
            padding="max_length",
            max_length=negative_ids.shape[-1],
            return_tensors="pt",
        ).input_ids.to("cuda")
        input_ids2 = pipe.tokenizer_2(
            prompt,
            truncation=False,
            padding="max_length",
            max_length=negative_ids.shape[-1],
            return_tensors="pt",
        ).input_ids.to("cuda")

    # Text encoder 1: encode each chunk and collect its last hidden states.
    concat_embeds = []
    neg_embeds = []
    for i in range(0, input_ids.shape[-1], max_length):
        concat_embeds.append(pipe.text_encoder(input_ids[:, i : i + max_length])[0])
        neg_embeds.append(pipe.text_encoder(negative_ids[:, i : i + max_length])[0])

    # Text encoder 2: take the penultimate hidden state per chunk and keep the
    # pooled projection for SDXL's pooled text embedding.
    concat_embeds2 = []
    neg_embeds2 = []
    pooled_embeds2 = []
    neg_pooled_embeds2 = []
    for i in range(0, input_ids.shape[-1], max_length):
        hidden_states = pipe.text_encoder_2(
            input_ids2[:, i : i + max_length], output_hidden_states=True
        )
        concat_embeds2.append(hidden_states.hidden_states[-2])
        pooled_embeds2.append(hidden_states[0])

        hidden_states = pipe.text_encoder_2(
            negative_ids2[:, i : i + max_length], output_hidden_states=True
        )
        neg_embeds2.append(hidden_states.hidden_states[-2])
        neg_pooled_embeds2.append(hidden_states[0])

    # Join chunks along the sequence dim, then the two encoders along the feature dim.
    prompt_embeds = torch.cat(concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    prompt_embeds2 = torch.cat(concat_embeds2, dim=1)
    negative_prompt_embeds2 = torch.cat(neg_embeds2, dim=1)
    prompt_embeds = torch.cat([prompt_embeds, prompt_embeds2], dim=-1)
    negative_prompt_embeds = torch.cat(
        [negative_prompt_embeds, negative_prompt_embeds2], dim=-1
    )

    # Average the pooled embeddings over chunks to get a single pooled vector per prompt.
    pooled_embeds2 = torch.mean(torch.stack(pooled_embeds2, dim=0), dim=0)
    neg_pooled_embeds2 = torch.mean(torch.stack(neg_pooled_embeds2, dim=0), dim=0)

    return prompt_embeds, negative_prompt_embeds, pooled_embeds2, neg_pooled_embeds2
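

# --- Usage sketch (not part of the original snippet) ---
# A minimal example of how the helpers above could be wired together. The
# prompts and generation settings are illustrative placeholders, and it
# assumes the SDXL k-diffusion pipeline's __call__ accepts precomputed prompt
# embeddings the same way the standard SDXL pipeline does.
if __name__ == "__main__":
    pipe = load_model()

    prompt = "1girl, solo, looking at viewer, masterpiece, best quality"  # placeholder
    neg_prompt = "lowres, bad anatomy, worst quality"  # placeholder

    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_embeds,
        neg_pooled_embeds,
    ) = encode_prompts(pipe, prompt, neg_prompt)

    image = pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_embeds,
        negative_pooled_prompt_embeds=neg_pooled_embeds,
        num_inference_steps=24,
        guidance_scale=6.0,
        width=1024,
        height=1024,
    ).images[0]
    image.save("output.png")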