Update app.py
Browse files
app.py
CHANGED
@@ -6,9 +6,9 @@ from safetensors.torch import load_file
|
|
6 |
|
7 |
# Define a function to generate the image
|
8 |
def generate_image(prompt, num_inference_steps):
|
9 |
-
base = "stable-diffusion"
|
10 |
-
repo = "
|
11 |
-
ckpt = "
|
12 |
|
13 |
# Load model.
|
14 |
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
|
|
|
6 |
|
7 |
# Define a function to generate the image
|
8 |
def generate_image(prompt, num_inference_steps):
|
9 |
+
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
10 |
+
repo = "ByteDance/SDXL-Lightning"
|
11 |
+
ckpt = "sdxl_lightning_2step_unet.safetensors" # Use the correct ckpt for your step setting!
|
12 |
|
13 |
# Load model.
|
14 |
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
|