Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -13,13 +13,14 @@ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
|
|
13 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
14 |
repo = "ByteDance/SDXL-Lightning"
|
15 |
checkpoints = {
|
16 |
-
"
|
17 |
-
"
|
18 |
-
"
|
19 |
-
"
|
20 |
}
|
21 |
loaded = None
|
22 |
|
|
|
23 |
# Ensure model and scheduler are initialized in GPU-enabled function
|
24 |
if torch.cuda.is_available():
|
25 |
pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
@@ -38,32 +39,39 @@ if SAFETY_CHECKER:
|
|
38 |
def check_nsfw_images(
|
39 |
images: list[Image.Image],
|
40 |
) -> tuple[list[Image.Image], list[bool]]:
|
41 |
-
safety_checker_input = feature_extractor(images
|
42 |
has_nsfw_concepts = safety_checker(
|
43 |
-
images=images,
|
44 |
-
clip_input=safety_checker_input.pixel_values.to("cuda")
|
45 |
)
|
46 |
|
47 |
-
return images, has_nsfw_concepts
|
48 |
|
|
|
49 |
@spaces.GPU(enable_queue=True)
|
50 |
def generate_image(prompt, ckpt):
|
51 |
global loaded
|
52 |
-
|
|
|
|
|
|
|
53 |
|
54 |
if loaded != num_inference_steps:
|
55 |
-
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps
|
56 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
|
57 |
loaded = num_inference_steps
|
58 |
-
|
59 |
-
results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=
|
60 |
|
61 |
if SAFETY_CHECKER:
|
62 |
images, has_nsfw_concepts = check_nsfw_images(results.images)
|
63 |
if any(has_nsfw_concepts):
|
64 |
-
|
|
|
|
|
65 |
return results.images[0]
|
66 |
|
|
|
67 |
description = """
|
68 |
🌌 Engage in the exploration of galaxies with the advanced SDXL-Lightning model, a creation of ByteDance capable of transforming your textual descriptions into vivid images at warp speed. This is a joint venture initiated by Starfleet, enabling creative minds to visualize the uncharted territories of space. 🚀 Link to model: [ByteDance/SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
|
69 |
"""
|
|
|
13 |
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
14 |
repo = "ByteDance/SDXL-Lightning"
|
15 |
checkpoints = {
|
16 |
+
"1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
|
17 |
+
"2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
|
18 |
+
"4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
|
19 |
+
"8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
|
20 |
}
|
21 |
loaded = None
|
22 |
|
23 |
+
|
24 |
# Ensure model and scheduler are initialized in GPU-enabled function
|
25 |
if torch.cuda.is_available():
|
26 |
pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
|
|
|
39 |
def check_nsfw_images(
|
40 |
images: list[Image.Image],
|
41 |
) -> tuple[list[Image.Image], list[bool]]:
|
42 |
+
safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
|
43 |
has_nsfw_concepts = safety_checker(
|
44 |
+
images=[images],
|
45 |
+
clip_input=safety_checker_input.pixel_values.to("cuda")
|
46 |
)
|
47 |
|
48 |
+
return images, has_nsfw_concepts
|
49 |
|
50 |
+
# Function
|
51 |
@spaces.GPU(enable_queue=True)
|
52 |
def generate_image(prompt, ckpt):
|
53 |
global loaded
|
54 |
+
print(prompt, ckpt)
|
55 |
+
|
56 |
+
checkpoint = checkpoints[ckpt][0]
|
57 |
+
num_inference_steps = checkpoints[ckpt][1]
|
58 |
|
59 |
if loaded != num_inference_steps:
|
60 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
|
61 |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
|
62 |
loaded = num_inference_steps
|
63 |
+
|
64 |
+
results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
|
65 |
|
66 |
if SAFETY_CHECKER:
|
67 |
images, has_nsfw_concepts = check_nsfw_images(results.images)
|
68 |
if any(has_nsfw_concepts):
|
69 |
+
gr.Warning("NSFW content detected.")
|
70 |
+
return Image.new("RGB", (512, 512))
|
71 |
+
return images[0]
|
72 |
return results.images[0]
|
73 |
|
74 |
+
# Gradio Interface
|
75 |
description = """
|
76 |
🌌 Engage in the exploration of galaxies with the advanced SDXL-Lightning model, a creation of ByteDance capable of transforming your textual descriptions into vivid images at warp speed. This is a joint venture initiated by Starfleet, enabling creative minds to visualize the uncharted territories of space. 🚀 Link to model: [ByteDance/SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
|
77 |
"""
|