Spaces:
Sleeping
Sleeping
import os, sys, pdb | |
import diffusers | |
from transformers import AutoTokenizer, PretrainedConfig | |
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler | |
def make_1step_sched(device="cuda", model_name="stabilityai/sd-turbo"):
    """Build a DDPM scheduler configured for a single inference timestep.

    Loads the scheduler config from ``model_name`` (subfolder ``"scheduler"``),
    calls ``set_timesteps(1, device=device)`` on it, and moves its
    ``alphas_cumprod`` tensor to ``device``.

    Args:
        device: Device passed to ``set_timesteps`` and used for
            ``alphas_cumprod``. Defaults to ``"cuda"`` (the previous
            hard-coded behavior).
        model_name: Hub id or local path whose ``scheduler`` subfolder is
            loaded. Defaults to ``"stabilityai/sd-turbo"``.

    Returns:
        The configured ``DDPMScheduler`` instance.
    """
    # Only one scheduler is needed; the previous version also loaded a second,
    # never-used copy, which was a wasted from_pretrained call.
    noise_scheduler_1step = DDPMScheduler.from_pretrained(model_name, subfolder="scheduler")
    noise_scheduler_1step.set_timesteps(1, device=device)
    # Keep alphas_cumprod on the same device as the timesteps configured above
    # (presumably so it can be indexed by on-device timesteps — confirm at call sites).
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.to(device)
    return noise_scheduler_1step
def my_vae_encoder_fwd(self, sample):
    """The forward method of the `Encoder` class.

    Encoder forward pass that also records, in ``self.current_down_blocks``,
    the input fed to each down block, in the order the down blocks run.

    Args:
        self: The encoder this function is bound to. It must provide
            ``conv_in``, an iterable ``down_blocks``, ``mid_block``,
            ``conv_norm_out``, ``conv_act`` and ``conv_out``.
        sample: Input passed to ``self.conv_in``.

    Returns:
        The output of ``self.conv_out``.

    Side effects:
        Overwrites ``self.current_down_blocks`` with a new list holding one
        entry per down block (the activation *before* that block). These are
        presumably consumed elsewhere as decoder skip connections — confirm
        the wiring at the call site.
    """
    sample = self.conv_in(sample)
    skip_acts = []
    # down: capture each block's input before applying the block
    for down_block in self.down_blocks:
        skip_acts.append(sample)
        sample = down_block(sample)
    # middle
    sample = self.mid_block(sample)
    # post-process
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    self.current_down_blocks = skip_acts
    return sample
def my_vae_decoder_fwd(self, sample, latent_embeds=None):
    """The forward method of the `Decoder` class.

    Decoder forward pass with optional additive skip connections. Unless
    ``self.ignore_skip`` is truthy, the skip term is added to the running
    sample before up block ``i`` (0-based) runs. The skip term is
    ``skip_conv_{i+1}(self.incoming_skip_acts[::-1][i] * self.gamma)``, so the
    stored activations are consumed in reverse order.

    Args:
        self: The decoder this function is bound to. It must provide
            ``conv_in``, ``mid_block``, ``up_blocks`` (iterable and exposing
            ``parameters()``), ``conv_norm_out``, ``conv_act``, ``conv_out``
            and ``ignore_skip``. When skips are enabled it also needs
            ``skip_conv_1`` .. ``skip_conv_4``, ``incoming_skip_acts``
            (assumed to be a list — TODO confirm) and ``gamma``.
        sample: Input passed to ``self.conv_in``.
        latent_embeds: Optional conditioning forwarded to ``mid_block`` and
            every up block. It is passed to ``conv_norm_out`` only when not
            ``None``.

    Returns:
        The output of ``self.conv_out``.

    Raises:
        IndexError: With skips enabled, if there are more up blocks than the
            four skip convs or than entries in ``incoming_skip_acts``.
    """
    sample = self.conv_in(sample)
    # Dtype of the up blocks' first parameter; the mid-block output is cast to it.
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    # middle
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    # up
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # Reverse once rather than re-slicing on every iteration. Indexing (not
        # zip) keeps the IndexError on length mismatch instead of silently
        # skipping up blocks.
        skip_acts = self.incoming_skip_acts[::-1]
        for idx, up_block in enumerate(self.up_blocks):
            skip_in = skip_convs[idx](skip_acts[idx] * self.gamma)
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for up_block in self.up_blocks:
            sample = up_block(sample, latent_embeds)
    # post-process
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample