File size: 3,354 Bytes
d0a8ced
1feb9c4
 
 
 
 
 
d0a8ced
1feb9c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import gradio as gr
from diffusers import DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from tqdm.auto import tqdm
from time import time
from PIL import Image

# Repo id of the checkpoint; every component below is loaded from one of its
# subfolders.
MODEL_ID = "KBlueLeaf/Kohaku-XL-Epsilon-rev3"

vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
textEncoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")
scheduler = DPMSolverMultistepScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

# Prefer the GPU, but fall back to CPU so the app can still start (slowly) on a
# machine without CUDA instead of crashing at import time.
torchDevice = "cuda" if torch.cuda.is_available() else "cpu"
vae.to(torchDevice)
textEncoder.to(torchDevice)
unet.to(torchDevice)

def generate(prompt: str, negativePrompt: str, steps: int, cfg: float, seed: int, randomized: bool, width: int, height: int):
  """Sample one image from the text prompt using classifier-free guidance.

  Args:
    prompt: Text describing the desired image. One string yields one image.
    negativePrompt: Text encoded for the unconditional branch of
      classifier-free guidance. An empty string reproduces plain
      unconditional guidance.
    steps: Number of scheduler denoising steps.
    cfg: Classifier-free guidance scale applied to (cond - uncond).
    seed: Seed for the initial latent noise. Ignored when ``randomized``.
    randomized: When True, draw a fresh random seed instead of ``seed``.
    width: Output width in pixels. The latent width is ``width // 8``.
    height: Output height in pixels. The latent height is ``height // 8``.

  Returns:
    A ``PIL.Image.Image`` built from the decoded uint8 pixel array.
  """
  # Gradio sliders/numbers may deliver floats; tensor sizes and
  # set_timesteps need ints.
  steps = int(steps)
  width = int(width)
  height = int(height)

  if randomized:
    # The upper bound is exclusive and must fit in int64 (the former literal
    # 9223372036854776000 did not).
    seed = torch.randint(0, torch.iinfo(torch.int64).max, (1,)).item()
  # Use a generator on the same device as the latents. torch.manual_seed()
  # returns the global CPU generator, which a CUDA randn() rejects.
  generator = torch.Generator(device=torchDevice).manual_seed(int(seed))

  # A single prompt string is one sample; len(prompt) would count characters.
  batchSize = 1

  with torch.no_grad():
    textInput = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    textEmbeddings = textEncoder(textInput.input_ids.to(torchDevice), attention_mask=textInput.attention_mask.to(torchDevice))[0]
    maxLength = textInput.input_ids.shape[-1]
    # The negative prompt feeds the unconditional CFG branch. Truncation keeps
    # its length equal to the prompt's so the two can be concatenated.
    unconditionedInput = tokenizer([negativePrompt or ""] * batchSize, padding="max_length", max_length=maxLength, truncation=True, return_tensors="pt")
    unconditionedEmbeddings = textEncoder(unconditionedInput.input_ids.to(torchDevice))[0]
    textEmbeddings = torch.cat([unconditionedEmbeddings, textEmbeddings])

    latents = torch.randn((batchSize, unet.config.in_channels, height // 8, width // 8), generator=generator, device=torchDevice)
    latents = latents * scheduler.init_noise_sigma

    # NOTE(review): if this UNet's config.addition_embed_type is "text_time"
    # (SDXL-style), it also expects added_cond_kwargs and a second text
    # encoder, which this loop does not supply -- confirm against the checkpoint.
    scheduler.set_timesteps(steps)
    for t in tqdm(scheduler.timesteps):
      # Run the unconditional and conditional branches in one batched call.
      latentModelInput = torch.cat([latents] * 2)
      latentModelInput = scheduler.scale_model_input(latentModelInput, timestep=t)
      noisePred = unet(latentModelInput, t, encoder_hidden_states=textEmbeddings).sample
      unconditionedNoisePred, noisePredText = noisePred.chunk(2)
      noisePred = unconditionedNoisePred + cfg * (noisePredText - unconditionedNoisePred)
      latents = scheduler.step(noisePred, t, latents).prev_sample

    # Undo the VAE's latent scaling. It is read from the VAE config rather than
    # hard-coding 0.18215 (the SD 1.x value), since VAEs differ.
    latents = latents / vae.config.scaling_factor
    image = vae.decode(latents).sample

  # Map the decoder output from [-1, 1] to [0, 1], then to uint8 exactly once.
  # The old code multiplied by 255 again after the uint8 cast, wrapping values.
  image = (image / 2 + 0.5).clamp(0, 1)
  image = (image[0].permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
  return Image.fromarray(image)

# Web UI: the input order must match generate()'s positional parameters.
interface = gr.Interface(fn=generate, inputs=[
    gr.Textbox(lines=3, placeholder="Prompt is here...", label="Prompt"),
    gr.Textbox(lines=3, placeholder="Negative prompt is here...", label="Negative Prompt"),
    # At least one denoising step; 0 steps would perform no denoising at all.
    gr.Slider(1, 1000, step=1, label="Steps", value=20),
    gr.Slider(0, 50, step=0.1, label="CFG Scale", value=8),
    # precision=0 makes Gradio pass an integer seed.
    gr.Number(label="Seed", value=0, precision=0),
    gr.Checkbox(label="Randomize Seed", value=True),
    # Bounded so a slider drag cannot request an enormous latent tensor.
    gr.Slider(256, 2048, step=64, label="Width", value=512),
    gr.Slider(256, 2048, step=64, label="Height", value=512),
  ], outputs="image")

def main():
  """Start the Gradio server hosting the text-to-image interface."""
  interface.launch()


if __name__ == "__main__":
  main()