Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -28,28 +28,24 @@ def generate(prompt, option, progress=gr.Progress()):
|
|
28 |
|
29 |
# Main pipeline.
|
30 |
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
|
31 |
-
unet.load_state_dict(load_file(hf_hub_download(repo, opts["4 Steps"][0]))).to(device, dtype)
|
32 |
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16")
|
33 |
-
pipe.
|
|
|
34 |
|
35 |
# Safety checker.
|
36 |
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
|
37 |
feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
38 |
image_processor = VaeImageProcessor(vae_scale_factor=8)
|
39 |
|
40 |
-
progress((0, step))
|
41 |
-
|
42 |
-
if step != step_loaded:
|
43 |
-
print(f"Switching checkpoint from {step_loaded} to {step}")
|
44 |
-
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
|
45 |
-
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
|
46 |
-
step_loaded = step
|
47 |
def inference_callback(p, i, t, kwargs):
|
48 |
progress((i+1, step))
|
49 |
return kwargs
|
|
|
|
|
|
|
50 |
results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pt")
|
51 |
|
52 |
-
# Safety check
|
53 |
feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
|
54 |
safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
|
55 |
images, has_nsfw_concept = safety_checker(
|
|
|
28 |
|
29 |
# Main pipeline.
|
30 |
unet = UNet2DConditionModel.from_config(base, subfolder="unet")
|
|
|
31 |
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=dtype, variant="fp16")
|
32 |
+
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
|
33 |
+
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", prediction_type="sample" if step == 1 else "epsilon")
|
34 |
|
35 |
# Safety checker.
|
36 |
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker").to(device, dtype)
|
37 |
feature_extractor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
|
38 |
image_processor = VaeImageProcessor(vae_scale_factor=8)
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
def inference_callback(p, i, t, kwargs):
|
41 |
progress((i+1, step))
|
42 |
return kwargs
|
43 |
+
|
44 |
+
# Inference loop.
|
45 |
+
progress((0, step))
|
46 |
results = pipe(prompt, num_inference_steps=step, guidance_scale=0, callback_on_step_end=inference_callback, output_type="pt")
|
47 |
|
48 |
+
# Safety check.
|
49 |
feature_extractor_input = image_processor.postprocess(results.images, output_type="pil")
|
50 |
safety_checker_input = feature_extractor(feature_extractor_input, return_tensors="pt")
|
51 |
images, has_nsfw_concept = safety_checker(
|