import os
import sys
import pdb

import diffusers
from transformers import AutoTokenizer, PretrainedConfig
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler


def make_1step_sched(device="cuda"):
    """Build an sd-turbo DDPM scheduler configured for single-step sampling.

    Args:
        device: Device on which the scheduler's timesteps and
            ``alphas_cumprod`` are placed. Defaults to ``"cuda"``.

    Returns:
        A ``DDPMScheduler`` loaded from the ``scheduler`` subfolder of
        ``stabilityai/sd-turbo`` with ``set_timesteps(1)`` applied.
    """
    # A single load is sufficient; no separate multi-step scheduler is needed.
    noise_scheduler_1step = DDPMScheduler.from_pretrained(
        "stabilityai/sd-turbo", subfolder="scheduler"
    )
    noise_scheduler_1step.set_timesteps(1, device=device)
    # Keep alphas_cumprod on the same device as the timesteps, presumably so
    # it can be indexed by on-device timestep tensors without a device
    # mismatch -- confirm against callers.
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.to(device)
    return noise_scheduler_1step


def my_vae_encoder_fwd(self, sample):
    """Replacement forward method for the diffusers VAE ``Encoder`` class.

    Runs ``conv_in``, each block in ``self.down_blocks``, ``mid_block`` and
    then ``conv_norm_out`` / ``conv_act`` / ``conv_out``. As a side effect it
    stores the input to every down block, in order, in
    ``self.current_down_blocks``. Those activations are presumably read back
    by a decoder as skip connections -- NOTE(review): the consumer and the
    binding of this function onto an ``Encoder`` instance are not shown here.

    Args:
        self: The ``Encoder`` instance this function is bound to.
        sample: Input tensor passed to ``self.conv_in``.

    Returns:
        The output of ``self.conv_out``.
    """
    sample = self.conv_in(sample)
    down_inputs = []
    # down: record each block's *input* before applying the block
    for down_block in self.down_blocks:
        down_inputs.append(sample)
        sample = down_block(sample)
    # middle
    sample = self.mid_block(sample)
    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    self.current_down_blocks = down_inputs
    return sample


def my_vae_decoder_fwd(self, sample, latent_embeds=None):
    """Replacement forward method for the diffusers VAE ``Decoder`` class.

    Runs ``conv_in`` and ``mid_block``, then casts to the dtype of the
    up-block parameters. When ``self.ignore_skip`` is falsy, before each
    up block it adds a skip term. That term is
    ``skip_conv_{i+1}(reversed(incoming_skip_acts)[i] * gamma)``.
    Finally it applies ``conv_norm_out`` / ``conv_act`` / ``conv_out``.

    The attributes ``ignore_skip``, ``skip_conv_1`` .. ``skip_conv_4``,
    ``gamma`` and ``incoming_skip_acts`` are not defined here. They are
    presumably attached to the decoder by the caller -- TODO confirm.
    Only four skip convs exist, so the skip path assumes
    ``len(self.up_blocks) <= 4``. It also assumes ``incoming_skip_acts`` has
    at least as many entries as there are up blocks.

    Args:
        self: The ``Decoder`` instance this function is bound to.
        sample: Input tensor passed to ``self.conv_in``.
        latent_embeds: Optional embedding forwarded to ``mid_block``, every
            up block, and ``conv_norm_out`` (only when not ``None``).

    Returns:
        The output of ``self.conv_out``.
    """
    sample = self.conv_in(sample)
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    # middle
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # Reverse once: the last recorded activation feeds the first up block.
        skip_acts = self.incoming_skip_acts[::-1]
        # up
        for idx, up_block in enumerate(self.up_blocks):
            skip_in = skip_convs[idx](skip_acts[idx] * self.gamma)
            # add skip
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
    # post-process
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample