Spaces:
Runtime error
Runtime error
File size: 3,354 Bytes
d0a8ced 1feb9c4 d0a8ced 1feb9c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
import gradio as gr
from diffusers import DPMSolverMultistepScheduler, AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer
import torch
from tqdm.auto import tqdm
from time import time
from PIL import Image
# Hugging Face Hub repository that every model component below is loaded from.
modelRepo = "KBlueLeaf/Kohaku-XL-Epsilon-rev3"

# Each component lives in its own subfolder of the repository.
vae = AutoencoderKL.from_pretrained(modelRepo, subfolder="vae")
tokenizer = CLIPTokenizer.from_pretrained(modelRepo, subfolder="tokenizer")
textEncoder = CLIPTextModel.from_pretrained(modelRepo, subfolder="text_encoder")
unet = UNet2DConditionModel.from_pretrained(modelRepo, subfolder="unet")
scheduler = DPMSolverMultistepScheduler.from_pretrained(modelRepo, subfolder="scheduler")

# Use the GPU when one is present, but fall back to the CPU instead of
# crashing at import time on hosts without CUDA.
torchDevice = "cuda" if torch.cuda.is_available() else "cpu"
vae.to(torchDevice)
textEncoder.to(torchDevice)
unet.to(torchDevice)
def _encodeText(texts):
    """Tokenize *texts* and return the text encoder's hidden states for them.

    Every input is padded/truncated to ``tokenizer.model_max_length`` so that
    the embeddings of different prompts can be concatenated along the batch
    dimension.
    """
    tokens = tokenizer(
        texts,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    # Only pass an attention mask when the encoder's config asks for one, so
    # that the positive and negative prompts are always encoded the same way.
    attentionMask = None
    if getattr(textEncoder.config, "use_attention_mask", False):
        attentionMask = tokens.attention_mask.to(torchDevice)
    return textEncoder(tokens.input_ids.to(torchDevice), attention_mask=attentionMask)[0]


def generate(prompt: str, negativePrompt: str, steps: int, cfg: float, seed: int, randomized: bool, width: int, height: int):
    """Run a classifier-free-guidance denoising loop and return one image.

    Args:
        prompt: Text the image should match.
        negativePrompt: Text the image should be steered away from; an empty
            value behaves like plain unconditional guidance.
        steps: Number of scheduler steps (values below 1 are raised to 1).
        cfg: Guidance scale applied to the conditional/unconditional
            difference.
        seed: Seed for the initial noise; used only when ``randomized`` is
            false.
        randomized: When true, ignore ``seed`` and draw a fresh
            non-deterministic seed.
        width: Output width in pixels; the latent width is ``width // 8``.
        height: Output height in pixels; the latent height is ``height // 8``.

    Returns:
        A ``PIL.Image.Image`` built from the decoded uint8 RGB array.
    """
    # NOTE(review): the repository name suggests an SDXL-family checkpoint. If
    # so, the UNet presumably also needs the second text encoder's embeddings
    # and `added_cond_kwargs` - confirm against the model card.

    # UI widgets may deliver floats (or None for an empty number box).
    steps = max(1, int(steps))
    width = int(width)
    height = int(height)

    # A private CPU generator keeps the global RNG untouched and makes a given
    # seed produce the same noise regardless of the compute device.
    generator = torch.Generator(device="cpu")
    if randomized:
        generator.seed()
    else:
        generator.manual_seed(int(seed or 0))

    # The tokenizer expects a batch; a bare string's len() would be its
    # character count, not the number of prompts.
    prompts = [prompt or ""]
    negativePrompts = [negativePrompt or ""]
    batchSize = len(prompts)

    with torch.no_grad():
        # Order matters: chunk(2) below assumes [unconditional, conditional].
        textEmbeddings = torch.cat([_encodeText(negativePrompts), _encodeText(prompts)])

        # Configure the schedule before reading any schedule-dependent value.
        scheduler.set_timesteps(steps)
        latents = torch.randn(
            (batchSize, unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(torchDevice) * scheduler.init_noise_sigma

        for t in tqdm(scheduler.timesteps):
            # Duplicate the latents so both guidance branches share one pass.
            latentModelInput = scheduler.scale_model_input(torch.cat([latents] * 2), timestep=t)
            noisePred = unet(latentModelInput, t, encoder_hidden_states=textEmbeddings).sample
            unconditionedNoisePred, noisePredText = noisePred.chunk(2)
            noisePred = unconditionedNoisePred + cfg * (noisePredText - unconditionedNoisePred)
            latents = scheduler.step(noisePred, t, latents).prev_sample

        # Prefer the VAE's own latent scale; 0.18215 is the previous
        # hard-coded value, kept as a fallback.
        scalingFactor = getattr(vae.config, "scaling_factor", 0.18215)
        image = vae.decode(latents / scalingFactor).sample

    # Map [-1, 1] -> [0, 1], take the single batch item, go to HWC, and scale
    # to 8-bit exactly once.
    image = (image / 2 + 0.5).clamp(0, 1)[0]
    pixels = (image.permute(1, 2, 0) * 255).round().to(torch.uint8).cpu().numpy()
    return Image.fromarray(pixels)
# Gradio front end; the order of `inputs` matches generate()'s positional
# parameters.
interface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(lines=3, placeholder="Prompt is here...", label="Prompt"),
        gr.Textbox(lines=3, placeholder="Negative prompt is here...", label="Negative Prompt"),
        # At least one step is required for the sampler to produce anything.
        gr.Slider(1, 1000, step=1, label="Steps", value=20),
        gr.Slider(0, 50, step=0.1, label="CFG Scale", value=8),
        # precision=0 makes the seed arrive as an integer instead of a float.
        gr.Number(label="Seed", value=0, precision=0),
        gr.Checkbox(label="Randomize Seed", value=True),
        # Cap the resolution: latent memory grows with width * height, so the
        # previous 999999 px ceiling could only end in an out-of-memory error.
        gr.Slider(256, 2048, step=64, label="Width", value=512),
        gr.Slider(256, 2048, step=64, label="Height", value=512),
    ],
    outputs="image",
)

if __name__ == "__main__":
    interface.launch()