File size: 1,283 Bytes
4fa2b26
dffdfdb
 
 
 
 
4fa2b26
dffdfdb
 
4fa2b26
dffdfdb
4fa2b26
 
 
 
 
 
 
 
dffdfdb
4fa2b26
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import streamlit as st
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Hugging Face Hub locations used to assemble the pipeline.
base, repo, ckpt = (
    # SDXL base repo: supplies the UNet config plus every other pipeline component.
    "stabilityai/stable-diffusion-xl-base-1.0",
    # Repo hosting the SDXL-Lightning UNet weights.
    "ByteDance/SDXL-Lightning",
    # UNet weights file inside `repo` (the filename marks it as the 4-step variant).
    "sdxl_lightning_4step_unet.safetensors",
)

# Load model (executed only once per server process: st.cache_resource hands
# back the same pipeline object on every later call / script rerun).
@st.cache_resource
def load_sdxl_pipeline():
    """Build the SDXL pipeline with the SDXL-Lightning UNet swapped in.

    The UNet architecture is taken from the ``unet`` subfolder config of
    ``base`` and its weights from ``ckpt`` in ``repo``. All remaining pipeline
    components are loaded from ``base`` as fp16 weights. Everything is placed
    on the "cuda" device.

    Returns:
        A ``StableDiffusionXLPipeline`` on "cuda" whose scheduler is an
        ``EulerDiscreteScheduler`` using "trailing" timestep spacing.
    """
    # Load the config dict explicitly. Passing a repo id straight to
    # from_config is deprecated in diffusers and scheduled for removal.
    unet_config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(unet_config).to("cuda", torch.float16)
    # Overwrite the freshly initialised weights with the Lightning checkpoint.
    unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
    pipe = StableDiffusionXLPipeline.from_pretrained(
        base, unet=unet, torch_dtype=torch.float16, variant="fp16"
    ).to("cuda")
    # SDXL-Lightning's usage instructions require "trailing" timestep spacing;
    # keep the base scheduler's other settings.
    pipe.scheduler = EulerDiscreteScheduler.from_config(
        pipe.scheduler.config, timestep_spacing="trailing"
    )
    return pipe

# Streamlit UI (this top-level code re-runs on every widget interaction).
# Sampling steps for the loaded UNet. This must match the "4step" checkpoint
# named in `ckpt`. The pipeline's own default (50 steps) is wrong for a
# distilled Lightning UNet.
LIGHTNING_STEPS = 4

st.title("Stable Diffusion XL Image Generation")
prompt = st.text_input("Enter your image prompt:")

if st.button("Generate Image"):
    # Treat whitespace-only input the same as an empty prompt.
    if not prompt.strip():
        st.warning("Please enter a prompt.")
    else:
        pipe = load_sdxl_pipeline()  # cached after the first call
        with torch.no_grad():
            # Lightning checkpoints are sampled without classifier-free guidance
            # (guidance_scale=0) and with exactly the distilled step count.
            image = pipe(
                prompt,
                num_inference_steps=LIGHTNING_STEPS,
                guidance_scale=0,
            ).images[0]

        st.image(image)