img2img?

#1
by nagolinc - opened

How do I do img2img with this model?

I naively tried this, but it didn't work, obviously.

from diffusers import StableDiffusionImg2ImgPipeline
import torch  # required for torch.float16 below — missing from the original snippet

# First attempt: load the Deci checkpoint through the stock img2img pipeline
# class. NOTE(review): the poster reports this does not work on its own —
# presumably because custom_pipeline resolves to Deci's text2img pipeline,
# so the UNet swap below is not sufficient. Confirm against the model card.
img2img = StableDiffusionImg2ImgPipeline.from_pretrained(
    'Deci/DeciDiffusion-v1-0',
    custom_pipeline='Deci/DeciDiffusion-v1-0',
    torch_dtype=torch.float16,
)

# Replace the UNet with Deci's flexible UNet, whose weights live in the
# 'flexible_unet' subfolder of the same repo.
img2img.unet = img2img.unet.from_pretrained(
    'Deci/DeciDiffusion-v1-0',
    subfolder='flexible_unet',
    torch_dtype=torch.float16,
)

# Move pipeline to device
img2img = img2img.to('cuda')

Ah — it turns out I can just adapt DeciDiffusionPipeline by subclassing StableDiffusionImg2ImgPipeline instead:

class DeciDiffusionPipeline_img2img(StableDiffusionImg2ImgPipeline):
    """img2img variant of DeciDiffusion.

    A ``StableDiffusionImg2ImgPipeline`` whose UNet is swapped for Deci's
    ``FlexibleUNet2DConditionModel``, with Deci's recommended defaults
    applied when the caller does not override them.
    """

    # Defaults mirrored from Deci's text2img DeciDiffusionPipeline.
    deci_default_number_of_iterations = 30
    deci_default_guidance_rescale = 0.7

    def __init__(self,
                 vae: AutoencoderKL,
                 text_encoder: CLIPTextModel,
                 tokenizer: CLIPTokenizer,
                 unet: UNet2DConditionModel,
                 scheduler: KarrasDiffusionSchedulers,
                 safety_checker: StableDiffusionSafetyChecker,
                 feature_extractor: CLIPImageProcessor,
                 requires_safety_checker: bool = True
                 ):
        # Replace the stock UNet with Deci's flexible UNet. The argument is
        # intentionally discarded; the real weights are loaded afterwards via
        # ``from_pretrained(..., subfolder='flexible_unet')``.
        unet = FlexibleUNet2DConditionModel()

        # ``super().__init__`` already calls ``register_modules`` with these
        # same components, so the original's explicit second
        # ``register_modules`` call was redundant and has been removed.
        super().__init__(vae=vae,
                         text_encoder=text_encoder,
                         tokenizer=tokenizer,
                         unet=unet,
                         scheduler=scheduler,
                         safety_checker=safety_checker,
                         feature_extractor=feature_extractor,
                         requires_safety_checker=requires_safety_checker
                         )

    def __call__(self, *args, **kwargs):
        """Run img2img generation, filling in Deci's default step count
        when the caller did not specify ``num_inference_steps``.

        ``guidance_rescale`` is deliberately left to the caller — the
        original author disabled that default for the img2img pipeline.
        """
        kwargs.setdefault('num_inference_steps',
                          self.deci_default_number_of_iterations)
        return super().__call__(*args, **kwargs)
nagolinc changed discussion status to closed

@nagolinc - I will give your implementation a try! Thank you for sharing it

Sign up or log in to comment