PeterL1n commited on
Commit
88dc089
1 Parent(s): c811b57

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -3
app.py CHANGED
@@ -6,6 +6,8 @@ from huggingface_hub import hf_hub_download
6
  from safetensors.torch import load_file
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
9
  base = "stabilityai/stable-diffusion-xl-base-1.0"
10
  repo = "ByteDance/SDXL-Lightning"
11
  opts = {
@@ -16,9 +18,9 @@ opts = {
16
  }
17
 
18
  step_loaded = 4
19
- unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, torch.float16)
20
  unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]), device=device))
21
- pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to(device)
22
 
23
  @spaces.GPU(enable_queue=True)
24
  def generate_image(prompt, option):
@@ -33,7 +35,7 @@ def generate_image(prompt, option):
33
  with gr.Blocks() as demo:
34
  gr.HTML(
35
  "<h1><center>SDXL-Lightning</center></h1>" +
36
- "<p><center>Lightning-fast text-to-image generation.</center></p>" +
37
  "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
38
  )
39
 
 
6
  from safetensors.torch import load_file
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
10
+
11
  base = "stabilityai/stable-diffusion-xl-base-1.0"
12
  repo = "ByteDance/SDXL-Lightning"
13
  opts = {
 
18
  }
19
 
20
  step_loaded = 4
21
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device, dtype)
22
  unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]), device=device))
23
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device)
24
 
25
  @spaces.GPU(enable_queue=True)
26
  def generate_image(prompt, option):
 
35
  with gr.Blocks() as demo:
36
  gr.HTML(
37
  "<h1><center>SDXL-Lightning</center></h1>" +
38
+ "<p><center>Lightning-fast text-to-image generation</center></p>" +
39
  "<p><center><a href='https://huggingface.co/ByteDance/SDXL-Lightning'>https://huggingface.co/ByteDance/SDXL-Lightning</a></center></p>"
40
  )
41