PeterL1n committed on
Commit
5ecb4a9
1 Parent(s): 59f3984

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -10
app.py CHANGED
@@ -28,28 +28,24 @@ def generate(prompt, option, progress=gr.Progress()):
28
 
29
  # Main pipeline.
30
  unet = UNet2DConditionModel.from_config(base, subfolder="unet")
31
- unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]))).to(device, dtype)
32
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16")
33
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing").to(device, dtype)
 
34
 
35
  # Safety checker.
36
  safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
37
  feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
38
  image_processor = VaeImageProcessor(vae_scale_factor=8)
39
 
40
- progress((0, step))
41
-
42
- if step != step_loaded:
43
- print(f"Switching checkpoint from {step_loaded} to {step}")
44
- pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
45
- pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
46
- step_loaded = step
47
  def inference_callback(p, i, t, kwargs):
48
  progress((i+1, step))
49
  return kwargs
 
 
 
50
  results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pt")
51
 
52
- # Safety check
53
  feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
54
  safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
55
  images, has_nsfw_concept = safety_checker(
 
28
 
29
  # Main pipeline.
30
  unet = UNet2DConditionModel.from_config(base, subfolder="unet")
 
31
  pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16")
32
+ pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
33
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
34
 
35
  # Safety checker.
36
  safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
37
  feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
38
  image_processor = VaeImageProcessor(vae_scale_factor=8)
39
 
 
 
 
 
 
 
 
40
  def inference_callback(p, i, t, kwargs):
41
  progress((i+1, step))
42
  return kwargs
43
+
44
+ # Inference loop.
45
+ progress((0, step))
46
  results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pt")
47
 
48
+ # Safety check.
49
  feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
50
  safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
51
  images, has_nsfw_concept = safety_checker(