"""Name-based selection of diffusers scheduler classes for a pipeline."""

from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DEISMultistepScheduler,
    DPMSolverMultistepScheduler,
    DPMSolverSinglestepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    HeunDiscreteScheduler,
    KDPM2AncestralDiscreteScheduler,
    KDPM2DiscreteScheduler,
    PNDMScheduler,
    UniPCMultistepScheduler,
)

# Lookup table from a user-facing scheduler name to the class implementing it.
# NOTE(review): "DDPMScheduler" and "PNDMScheduler" keep the "Scheduler" suffix
# while every other key drops it. The keys are left untouched here because
# callers may already depend on these exact strings.
SCHEDULER_MAPPING = {
    "DDIM": DDIMScheduler,
    "DDPMScheduler": DDPMScheduler,
    "DEISMultistep": DEISMultistepScheduler,
    "DPMSolverMultistep": DPMSolverMultistepScheduler,
    "DPMSolverSinglestep": DPMSolverSinglestepScheduler,
    "EulerAncestralDiscrete": EulerAncestralDiscreteScheduler,
    "EulerDiscrete": EulerDiscreteScheduler,
    "HeunDiscrete": HeunDiscreteScheduler,
    "KDPM2AncestralDiscrete": KDPM2AncestralDiscreteScheduler,
    "KDPM2Discrete": KDPM2DiscreteScheduler,
    "PNDMScheduler": PNDMScheduler,
    "UniPCMultistep": UniPCMultistepScheduler,
}


def get_scheduler(pipe, scheduler):
    """Swap the scheduler on *pipe* for the one registered under *scheduler*.

    The replacement is created by calling ``from_config`` on the selected
    class with the config of the scheduler currently attached to *pipe*.
    *pipe* is modified in place.

    Args:
        pipe: Object exposing a ``scheduler`` attribute whose value has a
            ``config`` attribute.
        scheduler: A key of ``SCHEDULER_MAPPING``.

    Returns:
        The same *pipe* object, with ``pipe.scheduler`` replaced.

    Raises:
        ValueError: If *scheduler* is not a key of ``SCHEDULER_MAPPING``.
    """
    # No mapping value is None, so None reliably signals an unknown name.
    scheduler_cls = SCHEDULER_MAPPING.get(scheduler)
    if scheduler_cls is None:
        raise ValueError(f"Invalid scheduler name {scheduler}")

    current_config = pipe.scheduler.config
    pipe.scheduler = scheduler_cls.from_config(current_config)
    return pipe