"""Text-to-image generation with Stable Diffusion XL: base pipeline with an
optional refiner and upscaler, exposed as a ZeroGPU Gradio handler."""

import time
from datetime import datetime

import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from gradio import Error, Info, Progress
from spaces import GPU

from .loader import Loader
from .logger import Logger
from .utils import cuda_collect, get_output_types, timer


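# ZeroGPU decorator from the `spaces` package: requests a GPU for the
# duration of each call when running on Hugging Face Spaces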
@GPU
def generate(
    positive_prompt="",
    negative_prompt="",
    seed=None,
    model="stabilityai/stable-diffusion-xl-base-1.0",
    scheduler="Euler",
    width=1024,
    height=1024,
    guidance_scale=6.0,
    inference_steps=40,
    deepcache=1,
    scale=1,
    num_images=1,
    use_karras=False,
    use_refiner=False,
    progress=Progress(track_tqdm=True),
):
    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    if positive_prompt.strip() == "":
        raise Error("You must enter a prompt")

    KIND = "txt2img"
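    # SDXL conditions on penultimate-layer CLIP hidden states without
    # normalization, hence this embeddings type for Compel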
    EMBEDDINGS_TYPE = ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED

    start = time.perf_counter()
    log = Logger("generate")
    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")

    loader = Loader()
    loader.load(
        KIND,
        model,
        scheduler,
        deepcache,
        scale,
        use_karras,
        use_refiner,
        progress,
    )

    refiner = loader.refiner
    pipeline = loader.pipeline
    upscaler = loader.upscaler

    # The loader returns no pipeline when the model can't be loaded,
    # which probably means a typo in the config
    if pipeline is None:
        raise Error(f"Error loading {model}")

    # Prompt embeddings for base and refiner: the SDXL base conditions on
    # both text encoders, while the refiner only uses the second one
    compel_1 = Compel(
        text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
        tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
        requires_pooled=[False, True],
        returned_embeddings_type=EMBEDDINGS_TYPE,
        dtype_for_device_getter=lambda _: pipeline.dtype,
        device=pipeline.device,
    )
    compel_2 = Compel(
        text_encoder=[pipeline.text_encoder_2],
        tokenizer=[pipeline.tokenizer_2],
        requires_pooled=[True],
        returned_embeddings_type=EMBEDDINGS_TYPE,
        dtype_for_device_getter=lambda _: pipeline.dtype,
        device=pipeline.device,
    )

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1e6) % (2**64)

    # Increment the seed after each iteration
    images = []
    current_seed = seed

    for _ in range(num_images):
        try:
            generator = torch.Generator(device=pipeline.device).manual_seed(current_seed)
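            # Compel stacks the [positive, negative] embeddings; the
            # [0:1]/[1:2] slices below split them back out while keeping
            # the batch dimension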
            conditioning_1, pooled_1 = compel_1([positive_prompt, negative_prompt])
            conditioning_2, pooled_2 = compel_2([positive_prompt, negative_prompt])
        except PromptParser.ParsingException:
            raise Error("Invalid prompt")

        pipeline_output_type, refiner_output_type = get_output_types(scale, use_refiner)

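        # With the refiner enabled, the base denoises the first 80% of the
        # schedule and hands its output off for the final 20% (SDXL's
        # "ensemble of expert denoisers" mode)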
        pipeline_kwargs = {
            "width": width,
            "height": height,
            "denoising_end": 0.8 if use_refiner else None,
            "generator": generator,
            "output_type": pipeline_output_type,
            "guidance_scale": guidance_scale,
            "num_inference_steps": inference_steps,
            "prompt_embeds": conditioning_1[0:1],
            "pooled_prompt_embeds": pooled_1[0:1],
            "negative_prompt_embeds": conditioning_1[1:2],
            "negative_pooled_prompt_embeds": pooled_1[1:2],
        }

        refiner_kwargs = {
            "denoising_start": 0.8,
            "generator": generator,
            "output_type": refiner_output_type,
            "guidance_scale": guidance_scale,
            "num_inference_steps": inference_steps,
            "prompt_embeds": conditioning_2[0:1],
            "pooled_prompt_embeds": pooled_2[0:1],
            "negative_prompt_embeds": conditioning_2[1:2],
            "negative_pooled_prompt_embeds": pooled_2[1:2],
        }

        image = pipeline(**pipeline_kwargs).images[0]

        if use_refiner:
            refiner_kwargs["image"] = image
            image = refiner(**refiner_kwargs).images[0]

        # Use a tuple so gallery images get captions
        images.append((image, str(current_seed)))
        current_seed += 1

    # Upscale each image in place, keeping its seed caption
    if scale > 1:
        with timer(f"Upscaling {num_images} image{'s' if num_images > 1 else ''} {scale}x", logger=log.info):
            for i, (image, image_seed) in enumerate(images):
                images[i] = (upscaler.predict(image), image_seed)

    # Flush cache after generating
    cuda_collect()

    end = time.perf_counter()
    msg = f"Generated {len(images)} image{'s' if len(images) > 1 else ''} in {end - start:.2f}s"
    log.info(msg)

    Info(msg)

    return images
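
# Illustrative usage (assumes a CUDA device and that the package's Loader can
# resolve the default model; the prompt, seed, and options are examples only):
#
#   images = generate(
#       positive_prompt="a watercolor lighthouse at dusk",
#       seed=42,
#       num_images=2,
#       use_refiner=True,
#   )
#   # -> [(PIL.Image, "42"), (PIL.Image, "43")], suitable for a gr.Gallery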