Quick Local Web Viewer

#21
by Metricon - opened

For those who might find it helpful, here is a quick local Python web interface (built with Gradio) for easy image generation and viewing:

This assumes local copies of the repositories at "D:/install/stable-diffusion-xl-base-1.0" and "D:/install/SDXL-Lightning". To download them from the Hugging Face Hub instead: remove `os.environ["TRANSFORMERS_OFFLINE"] = '1'`; replace `base = "D:/install/stable-diffusion-xl-base-1.0"` with `base = "stabilityai/stable-diffusion-xl-base-1.0"` and `repo = "D:/install/SDXL-Lightning"` with `repo = "ByteDance/SDXL-Lightning"`; and change `unet.load_state_dict(load_file(f"{repo}/{ckpt}", device="cuda"))` to `unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))`.

To use 8 steps instead of 4, replace `ckpt = "sdxl_lightning_4step_unet.safetensors"` with `ckpt = "sdxl_lightning_8step_unet.safetensors"` and change `num_inference_steps=4` to `num_inference_steps=8`.

import os
import torch
import numpy as np
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import gradio as gr

# Offline mode for Hugging Face libraries: both repositories below are local paths.
# NOTE(review): this is set after `diffusers` has already been imported above, so a
# library that reads TRANSFORMERS_OFFLINE at import time may ignore it — TODO confirm.
os.environ["TRANSFORMERS_OFFLINE"] = '1'
base = "D:/install/stable-diffusion-xl-base-1.0"  # SDXL base model repository (local copy)
repo = "D:/install/SDXL-Lightning"  # SDXL-Lightning repository (local copy)
ckpt = "sdxl_lightning_4step_unet.safetensors"  # 4-step UNet; pair with num_inference_steps=4

# Build a UNet from the base model's UNet config, then load the Lightning weights into it.
# `load_config` followed by `from_config` replaces passing a path straight to
# `from_config`, which diffusers deprecates.
unet = UNet2DConditionModel.from_config(
    UNet2DConditionModel.load_config(base, subfolder="unet")
).to("cuda", torch.float16)
unet.load_state_dict(load_file(f"{repo}/{ckpt}", device="cuda"))
# Assemble the SDXL pipeline from the base repo (fp16 variant) with the Lightning UNet swapped in.
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
# Euler scheduler with "trailing" timestep spacing; other settings come from the base scheduler config.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

def generate_image(prompt, seed=None, num_inference_steps=4, guidance_scale=0):
    """Generate one image for *prompt* with the module-level ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for a reproducible image, or ``None`` to draw a fresh
            non-deterministic seed. Numeric values are converted with ``int()``
            (e.g. a float coming from a numeric UI field).
        num_inference_steps: Denoising steps; should match the loaded
            checkpoint (4 for the 4-step UNet, 8 for the 8-step one).
        guidance_scale: Guidance scale forwarded to the pipeline (0 by default,
            as in the original script).

    Returns:
        The first entry of the pipeline output's ``.images``.
    """
    generator = torch.Generator("cuda")
    if seed is None:
        # A freshly constructed Generator starts from a fixed default seed, so
        # without reseeding every unseeded call would reproduce the same image.
        generator.seed()
    else:
        # Generator.manual_seed only accepts integers.
        generator.manual_seed(int(seed))

    with torch.no_grad():
        result = pipe(
            prompt,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=generator,
        ).images[0]
    return result

def generate_random_seed():
    """Return a random seed drawn from NumPy's global RNG, in ``[0, 2**31 - 2]``."""
    upper_exclusive = 2**31 - 1  # randint's upper bound is exclusive
    seed = np.random.randint(low=0, high=upper_exclusive)
    return seed

# Gradio UI: a prompt box, a seed field with a "random seed" button, an explicit
# generate button and an image preview. Editing the prompt, changing the seed, or
# pressing "Generate Image" all call generate_image with the current prompt and seed.
with gr.Blocks() as demo:
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", lines=2, interactive=True, placeholder="Type something...")
        with gr.Column():
            # The initial seed is drawn once, when this UI is built.
            seed_input = gr.Number(
                label="Seed",
                value=generate_random_seed(),
                precision=0,
                step=1,
                interactive=True,
            )
            random_seed_btn = gr.Button("Generate Random Seed")

    generate_btn = gr.Button("Generate Image")
    output = gr.Image(label="Generated Image", width=1024, height=1024)

    # Clicking the button writes a new random seed into the seed field.
    random_seed_btn.click(fn=generate_random_seed, inputs=[], outputs=seed_input)

    # Register the same generation handler on each trigger, in this order.
    for trigger in (prompt_input.change, seed_input.change, generate_btn.click):
        trigger(fn=generate_image, inputs=[prompt_input, seed_input], outputs=output)

demo.launch(inbrowser=True)

Sign up or log in to comment