PeterL1n commited on
Commit
0b120d5
1 Parent(s): edf024c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -2
app.py CHANGED
@@ -4,6 +4,7 @@ import spaces
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
5
  from huggingface_hub import hf_hub_download
6
  from safetensors.torch import load_file
 
7
 
8
  assert torch.cuda.is_available()
9
 
@@ -36,7 +37,9 @@ def generate(prompt, option, progress=gr.Progress()):
36
  print(prompt, option)
37
  ckpt, step = opts[option]
38
  if any(word in prompt for word in filter_words):
39
- return None
 
 
40
  progress((0, step))
41
  if step != step_loaded:
42
  print(f"Switching checkpoint from {step_loaded} to {step}")
@@ -46,7 +49,17 @@ def generate(prompt, option, progress=gr.Progress()):
46
  def inference_callback(p, i, t, kwargs):
47
  progress((i+1, step))
48
  return kwargs
49
- return pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback).images[0]
 
 
 
 
 
 
 
 
 
 
50
 
51
  with gr.Blocks(css="style.css") as demo:
52
  gr.HTML(
 
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
5
  from huggingface_hub import hf_hub_download
6
  from safetensors.torch import load_file
7
+ from PIL import Image
8
 
9
  assert torch.cuda.is_available()
10
 
 
37
  print(prompt, option)
38
  ckpt, step = opts[option]
39
  if any(word in prompt for word in filter_words):
40
+ gr.Warning("Safety checker triggered.")
41
+ print(f"Safety checker triggered on prompt: {prompt}")
42
+ return Image.new("RGB", (512, 512))
43
  progress((0, step))
44
  if step != step_loaded:
45
  print(f"Switching checkpoint from {step_loaded} to {step}")
 
49
  def inference_callback(p, i, t, kwargs):
50
  progress((i+1, step))
51
  return kwargs
52
+ results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback)
53
+ nsfw_content_detected = (
54
+ results.nsfw_content_detected[0]
55
+ if "nsfw_content_detected" in results
56
+ else False
57
+ )
58
+ if nsfw_content_detected:
59
+ gr.Warning("Safety checker triggered.")
60
+ print(f"Safety checker triggered on prompt: {prompt}")
61
+ return Image.new("RGB", (512, 512))
62
+ return results.images[0]
63
 
64
  with gr.Blocks(css="style.css") as demo:
65
  gr.HTML(