Linoy Tsaban commited on
Commit
eec6d5e
1 Parent(s): 2aeea2e

Update modified_pipeline_semantic_stable_diffusion.py

Browse files
modified_pipeline_semantic_stable_diffusion.py CHANGED
@@ -721,37 +721,37 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
721
  callback(i, t, latents)
722
 
723
 
724
- # # 8. Post-processing
725
- # image = self.decode_latents(latents)
726
 
727
- # # 9. Run safety checker
728
- # image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
729
 
730
- # # 10. Convert to PIL
731
- # if output_type == "pil":
732
- # image = self.numpy_to_pil(image)
733
 
734
- # if not return_dict:
735
- # return (image, has_nsfw_concept)
736
 
737
- #return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
738
 
739
- # 8. Post-processing
740
- if not output_type == "latent":
741
- image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
742
- image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
743
- else:
744
- image = latents
745
- has_nsfw_concept = None
746
 
747
- if has_nsfw_concept is None:
748
- do_denormalize = [True] * image.shape[0]
749
- else:
750
- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
751
 
752
- image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
753
 
754
- if not return_dict:
755
- return (image, has_nsfw_concept)
756
 
757
- return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
721
  callback(i, t, latents)
722
 
723
 
724
+ # 8. Post-processing
725
+ image = self.decode_latents(latents)
726
 
727
+ # 9. Run safety checker
728
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
729
 
730
+ # 10. Convert to PIL
731
+ if output_type == "pil":
732
+ image = self.numpy_to_pil(image)
733
 
734
+ if not return_dict:
735
+ return (image, has_nsfw_concept)
736
 
737
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
738
 
739
+ # # 8. Post-processing
740
+ # if not output_type == "latent":
741
+ # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
742
+ # image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
743
+ # else:
744
+ # image = latents
745
+ # has_nsfw_concept = None
746
 
747
+ # if has_nsfw_concept is None:
748
+ # do_denormalize = [True] * image.shape[0]
749
+ # else:
750
+ # do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
751
 
752
+ # image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
753
 
754
+ # if not return_dict:
755
+ # return (image, has_nsfw_concept)
756
 
757
+ # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)