Nikhil0987 commited on
Commit
4fa2b26
1 Parent(s): 565579e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -7
app.py CHANGED
@@ -1,16 +1,33 @@
 
1
  import torch
2
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
3
  from huggingface_hub import hf_hub_download
4
  from safetensors.torch import load_file
5
 
 
6
  base = "stabilityai/stable-diffusion-xl-base-1.0"
7
  repo = "ByteDance/SDXL-Lightning"
8
- ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
9
 
10
- # Load model.
11
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
12
- unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
13
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
 
 
 
14
 
15
- # Ensure sampler uses "trailing" timesteps.
16
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
  import torch
3
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
4
  from huggingface_hub import hf_hub_download
5
  from safetensors.torch import load_file
6
 
7
+ # Model Path/Repo Information
8
  base = "stabilityai/stable-diffusion-xl-base-1.0"
9
  repo = "ByteDance/SDXL-Lightning"
10
+ ckpt = "sdxl_lightning_4step_unet.safetensors"
11
 
12
+ # Load model (Executed only once for efficiency)
13
+ @st.cache_resource
14
+ def load_sdxl_pipeline():
15
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
16
+ unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
17
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
18
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
19
+ return pipe
20
 
21
+ # Streamlit UI
22
+ st.title("Stable Diffusion XL Image Generation")
23
+ prompt = st.text_input("Enter your image prompt:")
24
+
25
+ if st.button("Generate Image"):
26
+ if not prompt:
27
+ st.warning("Please enter a prompt.")
28
+ else:
29
+ pipe = load_sdxl_pipeline() # Load the pipeline from cache
30
+ with torch.no_grad():
31
+ image = pipe(prompt).images[0]
32
+
33
+ st.image(image)