File size: 5,372 Bytes
cb5daed
7736f5f
4d6f2bc
 
 
48c31e7
dffd0bb
9769856
 
ca5a1e4
9769856
aafe7f2
9769856
 
 
 
4d6f2bc
9769856
4d6f2bc
9769856
 
 
4d6f2bc
9769856
 
 
48c31e7
 
ca2f5d2
af07f4b
60849d7
9769856
6829539
4d6f2bc
9769856
 
 
4d6f2bc
1128e78
5c4e8c1
1128e78
9769856
 
48c31e7
9769856
 
 
4470520
9769856
 
98afd85
4d5d84d
4d6f2bc
ca2f5d2
 
9769856
 
6829539
9769856
6829539
9769856
 
 
4470520
6829539
9769856
6829539
 
9769856
 
039ff6d
9769856
6829539
4d6f2bc
9769856
039ff6d
4470520
9769856
 
 
 
ca2f5d2
 
79ce657
 
 
9769856
ca2f5d2
 
 
6829539
51fab87
6829539
9769856
 
1a7f234
9769856
6829539
9769856
 
6829539
4d6f2bc
9769856
 
 
 
 
6829539
 
79ce657
6829539
dffd0bb
9769856
4470520
972fe7d
6829539
dffd0bb
f70898c
6829539
 
 
 
 
4470520
6829539
 
f70898c
6829539
 
 
9769856
6829539
9769856
6829539
98afd85
9769856
98afd85
 
9769856
98afd85
9769856
 
 
6829539
 
9769856
 
4470520
6829539
ca2f5d2
9769856
6829539
51fab87
9e8b99d
9769856
9e8b99d
 
9769856
 
039ff6d
069fc81
 
aafe7f2
51fab87
6829539
aafe7f2
51fab87
9769856
 
 
6829539
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import os
import time
from datetime import datetime

import torch
from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
from compel.prompt_parser import PromptParser
from gradio import Error, Info, Progress
from spaces import GPU, config

from .loader import get_loader
from .logger import Logger
from .utils import annotate_image, cuda_collect, resize_image, timer


@GPU
def generate(
    positive_prompt="",
    negative_prompt="",
    image_input=None,
    controlnet_input=None,
    ip_adapter_input=None,
    seed=None,
    model="XpucT/Reliberate",
    scheduler="UniPC",
    controlnet_annotator="canny",
    width=512,
    height=512,
    guidance_scale=6.0,
    inference_steps=40,
    denoising_strength=0.8,
    deepcache_interval=1,
    scale=1,
    num_images=1,
    use_karras=False,
    use_ip_adapter_face=False,
    _=Progress(track_tqdm=True),
):
    """Generate ``num_images`` images from weighted text prompts.

    The pipeline kind is derived from the inputs: ``img2img`` when
    ``image_input`` is given, otherwise ``txt2img``; the kind is prefixed with
    ``controlnet_`` when ``controlnet_input`` is given.

    Args:
        positive_prompt: Prompt to embed with Compel; must not be blank.
        negative_prompt: Negative prompt. If it contains ``<fast_negative>``,
            the bundled ``fast_negative.pt`` textual inversion is loaded for
            the duration of the call.
        image_input: Optional source image for img2img; resized to
            ``(width, height)``.
        controlnet_input: Optional image passed through ``annotate_image``
            with ``controlnet_annotator``.
        ip_adapter_input: Optional IP-Adapter image; resized without an
            explicit size.
        seed: Seed of the first image; ``None`` or a negative value picks a
            time-based seed. Each following image uses the previous seed + 1.
        model, scheduler, deepcache_interval, use_karras: Forwarded to the
            loader.
        controlnet_annotator: Forwarded to the loader and to
            ``annotate_image``.
        width, height: Forwarded to the pipeline.
        guidance_scale: Forwarded to the pipeline as ``guidance_scale``.
        inference_steps: Forwarded to the pipeline as ``num_inference_steps``.
        denoising_strength: Forwarded as ``strength`` for img2img kinds.
        scale: Forwarded to the loader; when greater than 1 every image is
            run through the loader's upscaler.
        num_images: Number of images to generate.
        use_ip_adapter_face: Selects the ``full-face`` IP-Adapter kind instead
            of ``plus`` when ``ip_adapter_input`` is given.
        _: Gradio ``Progress`` default argument (created with
            ``track_tqdm=True``); not used in the body.

    Returns:
        A list of ``(image, seed)`` tuples where ``seed`` is the seed used for
        that image as a string (used as the gallery caption).

    Raises:
        Error: If CUDA is unavailable, the positive prompt is blank, the
            pipeline could not be loaded, or a prompt fails to parse.
    """
    if not torch.cuda.is_available():
        raise Error("CUDA not available")

    if not positive_prompt.strip():
        raise Error("You must enter a prompt")

    start = time.perf_counter()
    log = Logger("generate")
    log.info(f"Generating {num_images} image{'s' if num_images > 1 else ''}...")

    kind = "img2img" if image_input is not None else "txt2img"
    if controlnet_input is not None:
        kind = f"controlnet_{kind}"

    fast_negative = "<fast_negative>" in negative_prompt

    if ip_adapter_input:
        ip_kind = "full-face" if use_ip_adapter_face else "plus"
    else:
        ip_kind = ""

    # ZeroGPU is serverless so you want ephemeral instances
    # You want a singleton on localhost so the pipeline stays in memory
    loader = get_loader(singleton=not config.Config.zero_gpu)
    loader.load(
        kind,
        ip_kind,
        model,
        scheduler,
        controlnet_annotator,
        deepcache_interval,
        scale,
        use_karras,
    )

    pipeline = loader.pipeline
    upscaler = loader.upscaler

    # Probably a typo in the config
    if pipeline is None:
        raise Error(f"Error loading {model}")

    # https://pytorch.org/docs/stable/generated/torch.manual_seed.html
    if seed is None or seed < 0:
        seed = int(datetime.now().timestamp() * 1_000_000) % (2**64)

    # Load fast negative embedding
    if fast_negative:
        embeddings_dir = os.path.abspath(
            os.path.join(os.path.dirname(__file__), "..", "embeddings")
        )
        pipeline.load_textual_inversion(
            pretrained_model_name_or_path=os.path.join(embeddings_dir, "fast_negative.pt"),
            token="<fast_negative>",
        )

    images = []

    # Everything that runs while the embedding is loaded lives in this `try`
    # so it is unloaded exactly once: after ALL images are generated, and also
    # when parsing or generation fails (the pipeline may be a cached singleton).
    try:
        # Embed prompts with weights
        compel = Compel(
            device=pipeline.device,
            tokenizer=pipeline.tokenizer,
            truncate_long_prompts=False,
            text_encoder=pipeline.text_encoder,
            returned_embeddings_type=ReturnedEmbeddingsType.LAST_HIDDEN_STATES_NORMALIZED,
            dtype_for_device_getter=lambda _: pipeline.dtype,
            textual_inversion_manager=DiffusersTextualInversionManager(pipeline),
        )

        # The prompts do not change between images, so embed them once
        try:
            positive_embeds, negative_embeds = compel.pad_conditioning_tensors_to_same_length(
                [compel(positive_prompt), compel(negative_prompt)]
            )
        except PromptParser.ParsingException as err:
            raise Error("Invalid prompt") from err

        # Arguments shared by every image; only the generator varies
        kwargs = {
            "width": width,
            "height": height,
            "prompt_embeds": positive_embeds,
            "guidance_scale": guidance_scale,
            "num_inference_steps": inference_steps,
            "negative_prompt_embeds": negative_embeds,
            "output_type": "np" if scale > 1 else "pil",
        }

        if kind in ("img2img", "controlnet_img2img"):
            kwargs["strength"] = denoising_strength
            kwargs["image"] = resize_image(image_input, (width, height))

        if kind == "controlnet_txt2img":
            kwargs["image"] = annotate_image(controlnet_input, controlnet_annotator)

        if kind == "controlnet_img2img":
            kwargs["control_image"] = annotate_image(controlnet_input, controlnet_annotator)

        if ip_kind:
            # No size means preserve aspect ratio
            kwargs["ip_adapter_image"] = resize_image(ip_adapter_input)

        # Increment the seed after each iteration
        for offset in range(num_images):
            current_seed = seed + offset
            generator = torch.Generator(device=pipeline.device).manual_seed(current_seed)
            image = pipeline(**kwargs, generator=generator).images[0]
            images.append((image, str(current_seed)))  # tuple with seed for gallery caption
    finally:
        if fast_negative:
            pipeline.unload_textual_inversion()

    # Upscale
    if scale > 1:
        with timer(f"Upscaling {num_images} images {scale}x", logger=log.info):
            # Keep the (image, seed) tuple shape
            images = [(upscaler.predict(image), caption) for image, caption in images]

    end = time.perf_counter()
    msg = f"Generating {len(images)} image{'s' if len(images) > 1 else ''} took {end - start:.2f}s"
    log.info(msg)
    Info(msg)

    # Flush cache before returning
    cuda_collect()

    return images