PeterL1n commited on
Commit
6619b46
1 Parent(s): 0b120d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -2,6 +2,9 @@ import gradio as gr
2
  import torch
3
  import spaces
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
 
 
 
5
  from huggingface_hub import hf_hub_download
6
  from safetensors.torch import load_file
7
  from PIL import Image
@@ -27,6 +30,11 @@ unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0])))
27
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
28
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
29
 
 
 
 
 
 
30
  with open("filter.txt") as f:
31
  filter_words = {word for word in f.read().split("\n") if word}
32
 
@@ -37,7 +45,7 @@ def generate(prompt, option, progress=gr.Progress()):
37
  print(prompt, option)
38
  ckpt, step = opts[option]
39
  if any(word in prompt for word in filter_words):
40
- gr.Warning("Safety checker triggered.")
41
  print(f"Safety checker triggered on prompt: {prompt}")
42
  return Image.new("RGB", (512, 512))
43
  progress((0, step))
@@ -49,17 +57,18 @@ def generate(prompt, option, progress=gr.Progress()):
49
  def inference_callback(p, i, t, kwargs):
50
  progress((i+1, step))
51
  return kwargs
52
- results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback)
53
- nsfw_content_detected = (
54
- results.nsfw_content_detected[0]
55
- if "nsfw_content_detected" in results
56
- else False
 
 
57
  )
58
- if nsfw_content_detected:
59
- gr.Warning("Safety checker triggered.")
60
  print(f"Safety checker triggered on prompt: {prompt}")
61
- return Image.new("RGB", (512, 512))
62
- return results.images[0]
63
 
64
  with gr.Blocks(css="style.css") as demo:
65
  gr.HTML(
 
2
  import torch
3
  import spaces
4
  from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
5
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
6
+ from diffusers.image_processor import VaeImageProcessor
7
+ from transformers import CLIPImageProcessor
8
  from huggingface_hub import hf_hub_download
9
  from safetensors.torch import load_file
10
  from PIL import Image
 
30
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16").to(device, dtype)
31
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
32
 
33
+ # Safety checker.
34
+ safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
35
+ feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
36
+ image_processor = VaeImageProcessor(vae_scale_factor=8)
37
+
38
  with open("filter.txt") as f:
39
  filter_words = {word for word in f.read().split("\n") if word}
40
 
 
45
  print(prompt, option)
46
  ckpt, step = opts[option]
47
  if any(word in prompt for word in filter_words):
48
+ gr.Warning("Safety checker triggered. Image may contain violent or sexual content.")
49
  print(f"Safety checker triggered on prompt: {prompt}")
50
  return Image.new("RGB", (512, 512))
51
  progress((0, step))
 
57
  def inference_callback(p, i, t, kwargs):
58
  progress((i+1, step))
59
  return kwargs
60
+ results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pil")
61
+
62
+ # Safety check
63
+ feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
64
+ safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
65
+ images, has_nsfw_concept = safety_checker(
66
+ images=results.images, clip_input=safety_checker_input.pixel_values.to(device, dtype)
67
  )
68
+ if has_nsfw_concept[0]:
69
+ gr.Warning("Safety checker triggered. Image may contain violent or sexual content.")
70
  print(f"Safety checker triggered on prompt: {prompt}")
71
+ return images[0]
 
72
 
73
  with gr.Blocks(css="style.css") as demo:
74
  gr.HTML(