Kabatubare commited on
Commit
2fc58e2
·
verified ·
1 Parent(s): c2ced1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -13
app.py CHANGED
@@ -13,13 +13,14 @@ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
13
  base = "stabilityai/stable-diffusion-xl-base-1.0"
14
  repo = "ByteDance/SDXL-Lightning"
15
  checkpoints = {
16
- "Warp 1": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
17
- "Warp 2": ["sdxl_lightning_2step_unet.safetensors", 2],
18
- "Warp 4": ["sdxl_lightning_4step_unet.safetensors", 4],
19
- "Warp 8": ["sdxl_lightning_8step_unet.safetensors", 8],
20
  }
21
  loaded = None
22
 
 
23
  # Ensure model and scheduler are initialized in GPU-enabled function
24
  if torch.cuda.is_available():
25
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
@@ -38,32 +39,39 @@ if SAFETY_CHECKER:
38
  def check_nsfw_images(
39
  images: list[Image.Image],
40
  ) -> tuple[list[Image.Image], list[bool]]:
41
- safety_checker_input = feature_extractor(images=[image.convert("RGB") for image in images], return_tensors="pt").to("cuda")
42
  has_nsfw_concepts = safety_checker(
43
- images=images,
44
- clip_input=safety_checker_input.pixel_values.to("cuda"),
45
  )
46
 
47
- return images, has_nsfw_concepts.bool().tolist()
48
 
 
49
  @spaces.GPU(enable_queue=True)
50
  def generate_image(prompt, ckpt):
51
  global loaded
52
- checkpoint, num_inference_steps = checkpoints[ckpt]
 
 
 
53
 
54
  if loaded != num_inference_steps:
55
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps == 1 else "epsilon")
56
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
57
  loaded = num_inference_steps
58
-
59
- results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=7.5)
60
 
61
  if SAFETY_CHECKER:
62
  images, has_nsfw_concepts = check_nsfw_images(results.images)
63
  if any(has_nsfw_concepts):
64
- return Image.new("RGB", (512, 512), "black")
 
 
65
  return results.images[0]
66
 
 
67
  description = """
68
  🌌 Engage in the exploration of galaxies with the advanced SDXL-Lightning model, a creation of ByteDance capable of transforming your textual descriptions into vivid images at warp speed. This is a joint venture initiated by Starfleet, enabling creative minds to visualize the uncharted territories of space. 🚀 Link to model: [ByteDance/SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
69
  """
 
13
  base = "stabilityai/stable-diffusion-xl-base-1.0"
14
  repo = "ByteDance/SDXL-Lightning"
15
  checkpoints = {
16
+ "1-Step" : ["sdxl_lightning_1step_unet_x0.safetensors", 1],
17
+ "2-Step" : ["sdxl_lightning_2step_unet.safetensors", 2],
18
+ "4-Step" : ["sdxl_lightning_4step_unet.safetensors", 4],
19
+ "8-Step" : ["sdxl_lightning_8step_unet.safetensors", 8],
20
  }
21
  loaded = None
22
 
23
+
24
  # Ensure model and scheduler are initialized in GPU-enabled function
25
  if torch.cuda.is_available():
26
  pipe = StableDiffusionXLPipeline.from_pretrained(base, torch_dtype=torch.float16, variant="fp16").to("cuda")
 
39
  def check_nsfw_images(
40
  images: list[Image.Image],
41
  ) -> tuple[list[Image.Image], list[bool]]:
42
+ safety_checker_input = feature_extractor(images, return_tensors="pt").to("cuda")
43
  has_nsfw_concepts = safety_checker(
44
+ images=[images],
45
+ clip_input=safety_checker_input.pixel_values.to("cuda")
46
  )
47
 
48
+ return images, has_nsfw_concepts
49
 
50
+ # Function
51
  @spaces.GPU(enable_queue=True)
52
  def generate_image(prompt, ckpt):
53
  global loaded
54
+ print(prompt, ckpt)
55
+
56
+ checkpoint = checkpoints[ckpt][0]
57
+ num_inference_steps = checkpoints[ckpt][1]
58
 
59
  if loaded != num_inference_steps:
60
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if num_inference_steps==1 else "epsilon")
61
  pipe.unet.load_state_dict(load_file(hf_hub_download(repo, checkpoint), device="cuda"))
62
  loaded = num_inference_steps
63
+
64
+ results = pipe(prompt, num_inference_steps=num_inference_steps, guidance_scale=0)
65
 
66
  if SAFETY_CHECKER:
67
  images, has_nsfw_concepts = check_nsfw_images(results.images)
68
  if any(has_nsfw_concepts):
69
+ gr.Warning("NSFW content detected.")
70
+ return Image.new("RGB", (512, 512))
71
+ return images[0]
72
  return results.images[0]
73
 
74
+ # Gradio Interface
75
  description = """
76
  🌌 Engage in the exploration of galaxies with the advanced SDXL-Lightning model, a creation of ByteDance capable of transforming your textual descriptions into vivid images at warp speed. This is a joint venture initiated by Starfleet, enabling creative minds to visualize the uncharted territories of space. 🚀 Link to model: [ByteDance/SDXL-Lightning](https://huggingface.co/ByteDance/SDXL-Lightning)
77
  """