vittore commited on
Commit
26d6ed3
1 Parent(s): 2a7d65d

Fix description

Browse files
Files changed (1) hide show
  1. app.py +30 -18
app.py CHANGED
@@ -27,9 +27,6 @@ from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionS
27
 
28
  BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
29
 
30
- # Initialize both pipelines
31
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
32
- controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)
33
 
34
  # Initialize the safety checker conditionally
35
  SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
@@ -39,20 +36,36 @@ device='cuda'
39
  device='cpu'
40
 
41
 
 
 
 
 
 
 
 
 
 
42
  if SAFETY_CHECKER_ENABLED:
43
  safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
44
  feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
45
 
46
-
47
- main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
48
- BASE_MODEL,
49
- controlnet=controlnet,
50
- vae=vae,
51
- safety_checker=safety_checker,
52
- feature_extractor=feature_extractor,
53
- torch_dtype=torch.float16,
54
- ).to(device)
55
-
 
 
 
 
 
 
 
56
  # Function to check NSFW images
57
  #def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
58
  # if SAFETY_CHECKER_ENABLED:
@@ -225,15 +238,14 @@ def inference(
225
  with gr.Blocks() as app:
226
  gr.Markdown(
227
  '''
228
- <div style="text-align: center;">
229
  <h1>pattern + prompt = image</h1>
230
- <p style="font-size:16px;">Generate stunning high quality illusion artwork with Stable Diffusion</p>
231
- <p>Illusion Diffusion is back up with a safety checker! Because I have been asked, if you would like to support me, consider using <a href="https://deforum.studio">deforum.studio</a></p>
232
- <p>With big contributions from
233
  <ul>
234
  <li><a href="https://twitter.com/multimodalart">multimodalart</a></li>
235
  <li><a href="https://huggingface.co/monster-labs/control_v1p_sd15_qrcode_monster">Monster Labs QR Control Net</a></li>
236
- <li><a href="https://twitter.com/MrUgleh">MrUgleh</a>/li>
 
237
  </ul>
238
  </div>
239
  '''
 
27
 
28
  BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
29
 
 
 
 
30
 
31
  # Initialize the safety checker conditionally
32
  SAFETY_CHECKER_ENABLED = os.environ.get("SAFETY_CHECKER", "0") == "1"
 
36
  device='cpu'
37
 
38
 
39
+ # Initialize both pipelines
40
+ if device=='cuda':
41
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
42
+ controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster", torch_dtype=torch.float16)
43
+ else:
44
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
45
+ controlnet = ControlNetModel.from_pretrained("monster-labs/control_v1p_sd15_qrcode_monster")
46
+
47
+
48
  if SAFETY_CHECKER_ENABLED:
49
  safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device)
50
  feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
51
 
52
+ if device=='cuda':
53
+ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
54
+ BASE_MODEL,
55
+ controlnet=controlnet,
56
+ vae=vae,
57
+ safety_checker=safety_checker,
58
+ feature_extractor=feature_extractor,
59
+ torch_dtype=torch.float16,
60
+ ).to(device)
61
+ else:
62
+ main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
63
+ BASE_MODEL,
64
+ controlnet=controlnet,
65
+ vae=vae,
66
+ safety_checker=safety_checker,
67
+ feature_extractor=feature_extractor
68
+ ).to(device)
69
  # Function to check NSFW images
70
  #def check_nsfw_images(images: list[Image.Image]) -> tuple[list[Image.Image], list[bool]]:
71
  # if SAFETY_CHECKER_ENABLED:
 
238
  with gr.Blocks() as app:
239
  gr.Markdown(
240
  '''
241
+ <div>
242
  <h1>pattern + prompt = image</h1>
243
+ <p>With big contributions from:</p>
 
 
244
  <ul>
245
  <li><a href="https://twitter.com/multimodalart">multimodalart</a></li>
246
  <li><a href="https://huggingface.co/monster-labs/control_v1p_sd15_qrcode_monster">Monster Labs QR Control Net</a></li>
247
+ <li><a href="https://twitter.com/MrUgleh">MrUgleh</a></li>
248
+ <li><a href="https://huggingface.co/spaces/AP123/IllusionDiffusion">https://huggingface.co/spaces/AP123/IllusionDiffusion</a> - use it for GPU speed!</li>
249
  </ul>
250
  </div>
251
  '''