Runtime error
update
- KBlueLeaf-Kohaku-XL-Epsilon-rev3.code-workspace +8 -0
- app.py +64 -1
KBlueLeaf-Kohaku-XL-Epsilon-rev3.code-workspace
ADDED
@@ -0,0 +1,8 @@
+{
+    "folders": [
+        {
+            "path": "."
+        }
+    ],
+    "settings": {}
+}
app.py
CHANGED
@@ -1,3 +1,66 @@
 import gradio as gr
+from diffusers import DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
+from transformers import CLIPTextModel, CLIPTokenizer
+import torch
+from tqdm.auto import tqdm
+from PIL import Image
 
-
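+# load each pipeline component from the model repo by subfolder rather than through a Pipeline class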
+vae = AutoencoderKL.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="vae")
+tokenizer = CLIPTokenizer.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="tokenizer")
+textEncoder = CLIPTextModel.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="text_encoder")
+unet = UNet2DConditionModel.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="unet")
+scheduler = DPMSolverMultistepScheduler.from_pretrained("KBlueLeaf/Kohaku-XL-Epsilon-rev3", subfolder="scheduler")
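+# NOTE: Kohaku-XL-Epsilon is an SDXL-family checkpoint, so the repo also ships tokenizer_2 and
+# text_encoder_2, and an SDXL UNet expects added_cond_kwargs (pooled text embeddings and time_ids)
+# in its forward call. The single-encoder SD1.x-style loop below omits all of that, which is the
+# most likely source of this Space's "Runtime error" status.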
+
+torchDevice = "cuda"
+vae.to(torchDevice)
+textEncoder.to(torchDevice)
+unet.to(torchDevice)
+
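+# generate() runs the manual diffusion loop: encode the prompts with CLIP, denoise random
+# latents under classifier-free guidance, then decode the result with the VAE.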
+def generate(prompt: str, negativePrompt: str, steps: int, cfg: float, seed: int, randomized: bool, width: int, height: int):
+    if randomized:
+        seed = int(torch.randint(10000, 2**63 - 1, (1,)).item())
+    # seed the RNG with the chosen seed; the original called torch.manual_seed(time()),
+    # which passed a float and ignored the Seed input entirely
+    generator = torch.manual_seed(int(seed))
+    batchSize = 1  # one prompt string -> one image; len(prompt) counted characters
+    textInput = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+    with torch.no_grad():
+        textEmbeddings = textEncoder(textInput.input_ids.to(torchDevice), attention_mask=textInput.attention_mask.to(torchDevice))[0]
+    maxLength = textInput.input_ids.shape[-1]
+    # use the negative prompt for the unconditional branch (it was previously ignored)
+    unconditionedInput = tokenizer([negativePrompt] * batchSize, padding="max_length", max_length=maxLength, truncation=True, return_tensors="pt")
+    with torch.no_grad():
+        unconditionedEmbeddings = textEncoder(unconditionedInput.input_ids.to(torchDevice))[0]
+    textEmbeddings = torch.cat([unconditionedEmbeddings, textEmbeddings])
+
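+    # latents start as pure noise at 1/8 the target resolution (the VAE downsamples by 8x)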
+    latents = torch.randn((batchSize, unet.config.in_channels, int(height) // 8, int(width) // 8), generator=generator)
+    latents = latents.to(torchDevice)  # sample with the CPU generator, then move to the GPU
+    latents = latents * scheduler.init_noise_sigma
+
+    scheduler.set_timesteps(int(steps))
+    for t in tqdm(scheduler.timesteps):
+        # run the conditional and unconditional passes in a single batched forward call
+        latentModelInput = torch.cat([latents] * 2)
+        latentModelInput = scheduler.scale_model_input(latentModelInput, timestep=t)
+        with torch.no_grad():
+            noisePred = unet(latentModelInput, t, encoder_hidden_states=textEmbeddings).sample
+        # classifier-free guidance: push the prediction away from the unconditional branch
+        unconditionedNoisePred, noisePredText = noisePred.chunk(2)
+        noisePred = unconditionedNoisePred + cfg * (noisePredText - unconditionedNoisePred)
+        latents = scheduler.step(noisePred, t, latents).prev_sample
+
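+    # decode: un-scale the latents, run the VAE decoder, and convert to an 8-bit PIL image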
+    latents = latents / vae.config.scaling_factor  # undo the VAE latent scaling via the config instead of a hard-coded 1 / 0.18215
+    with torch.no_grad():
+        image = vae.decode(latents).sample
+    image = (image / 2 + 0.5).clamp(0, 1).squeeze()
+    # convert to 8 bits exactly once; the original scaled by 255 twice, blowing the image out to white
+    image = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
+    return Image.fromarray(image)
+
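+# the Interface inputs below must stay in the same order as generate()'s parameters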
+interface = gr.Interface(fn=generate, inputs=[
+    gr.Textbox(lines=3, placeholder="Prompt is here...", label="Prompt"),
+    gr.Textbox(lines=3, placeholder="Negative prompt is here...", label="Negative Prompt"),
+    gr.Slider(1, 1000, step=1, label="Steps", value=20),
+    gr.Slider(0, 50, step=0.1, label="CFG Scale", value=8),
+    gr.Number(label="Seed", value=0, precision=0),
+    gr.Checkbox(label="Randomize Seed", value=True),
+    gr.Slider(256, 999999, step=64, label="Width", value=512),
+    gr.Slider(256, 999999, step=64, label="Height", value=512),
+], outputs="image")
+
+if __name__ == "__main__":
+    interface.launch()