import torch
import torch.nn as nn
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from transformers import PretrainedConfig, PreTrainedModel


class SEPath(nn.Module):
    """Squeeze-and-Excitation gate: channel statistics from an encoder skip
    tensor modulate the corresponding decoder tensor."""

    def __init__(self, in_channels, out_channels, reduction=16):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_channels // reduction, out_channels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, in_tensor, out_tensor):
        B, C, H, W = in_tensor.size()
        # Squeeze: global average pooling over the spatial dimensions.
        x = in_tensor.view(B, C, -1).mean(dim=2)
        # Excitation: per-channel gates in (0, 1), broadcast over H and W.
        x = self.fc(x).unsqueeze(2).unsqueeze(2)
        return out_tensor * x


class SeResVaeConfig(PretrainedConfig):
    model_type = "seresvae"

    def __init__(
        self,
        base_model="stabilityai/stable-diffusion-2-1",
        height=512,
        width=512,
        **kwargs,
    ):
        self.base_model = base_model
        self.height = height
        self.width = width
        super().__init__(**kwargs)


class SeResVaeModel(PreTrainedModel):
    config_class = SeResVaeConfig

    def __init__(self, config):
        super().__init__(config)
        self.image_processor = VaeImageProcessor()
        self.vae = AutoencoderKL.from_pretrained(config.base_model, subfolder="vae")
        self.unet = UNet2DConditionModel.from_pretrained(config.base_model, subfolder="unet")
        # One SE gate per skip connection: latent, mid block, and three encoder
        # down-block outputs (channel counts follow the SD 2.1 VAE).
        self.se_paths = nn.ModuleList(
            [SEPath(8, 4), SEPath(512, 512), SEPath(512, 512), SEPath(256, 512), SEPath(128, 256)]
        )
        # Learned prompt embedding used as the UNet's text conditioning.
        self.prompt_embeds = nn.Parameter(torch.randn(1, 77, 1024))
        self.height = config.height
        self.width = config.width

    def forward(self, images_gray, input_type="pil", output_type="pil"):
        if input_type == "pil":
            images_gray = self.image_processor.preprocess(
                images_gray, height=self.height, width=self.width
            ).float()
        elif input_type == "pt":
            pass  # already a preprocessed tensor
        else:
            raise ValueError("unsupported input_type")
        images_gray = images_gray.to(self.vae.device)

        B, C, H, W = images_gray.shape
        prompt_embeds = self.prompt_embeds.repeat(B, 1, 1)

        # Encode while keeping intermediate activations for the SE skip paths.
        posterior, encode_residuals = self.encode_with_residual(images_gray)
        latents = posterior.mode()

        # Single denoising step at a fixed timestep (t = 500).
        t = torch.full((B,), 500, dtype=torch.long, device=self.vae.device)
        noise_pred = self.unet(latents, t, encoder_hidden_states=prompt_embeds)[0]
        denoised_latents = latents - noise_pred

        images_rgb = self.decode_with_residual(denoised_latents, *encode_residuals)

        if output_type == "pil":
            images_rgb = self.image_processor.postprocess(images_rgb)
        elif output_type == "np":
            images_rgb = self.image_processor.postprocess(images_rgb, "np")
        elif output_type == "pt":
            images_rgb = self.image_processor.postprocess(images_rgb, "pt")
        elif output_type == "none":
            pass  # raw decoder output
        else:
            raise ValueError("unsupported output_type")
        return images_rgb

    def encode_with_residual(self, sample):
        """Run the VAE encoder and return the posterior plus the intermediate
        activations consumed by the SE skip paths."""
        re = self.vae.encoder.conv_in(sample)
        re0, re0_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[0], re)
        re1, re1_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[1], re0)
        re2, re2_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[2], re1)
        re3, re3_out = self._DownEncoderBlock2D_res_forward(self.vae.encoder.down_blocks[3], re2)
        rem = self.vae.encoder.mid_block(re3)
        re_out = self.vae.encoder.conv_norm_out(rem)
        re_out = self.vae.encoder.conv_act(re_out)
        re_out = self.vae.encoder.conv_out(re_out)
        re_out = self.vae.quant_conv(re_out)
        posterior = DiagonalGaussianDistribution(re_out)
        return posterior, (re0_out, re1_out, re2_out, rem, re_out)

    def decode_with_residual(self, z, re0_out, re1_out, re2_out, rem, re_out):
        """Run the VAE decoder, gating each stage with the matching encoder
        activation through an SE path."""
        rd = self.vae.post_quant_conv(self.se_paths[0](re_out, z))
        rd = self.vae.decoder.conv_in(rd)
        rdm = self.vae.decoder.mid_block(self.se_paths[1](rem, rd)).to(torch.float32)
        rd0 = self.vae.decoder.up_blocks[0](rdm)
        rd1 = self.vae.decoder.up_blocks[1](self.se_paths[2](re2_out, rd0))
        rd2 = self.vae.decoder.up_blocks[2](self.se_paths[3](re1_out, rd1))
        rd3 = self.vae.decoder.up_blocks[3](self.se_paths[4](re0_out, rd2))
        rd_out = self.vae.decoder.conv_norm_out(rd3)
        rd_out = self.vae.decoder.conv_act(rd_out)
        sample_out = self.vae.decoder.conv_out(rd_out)
        return sample_out

    def _DownEncoderBlock2D_res_forward(self, down_encoder_block_2d, hidden_states):
        """Forward a DownEncoderBlock2D, returning both the downsampled output
        and the pre-downsampling activation used as an SE skip."""
        for resnet in down_encoder_block_2d.resnets:
            hidden_states = resnet(hidden_states, temb=None)
        output_states = hidden_states
        if down_encoder_block_2d.downsamplers is not None:
            for downsampler in down_encoder_block_2d.downsamplers:
                hidden_states = downsampler(hidden_states)
        return hidden_states, output_states
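

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): build the model
    # from the default SD 2.1 weights and run one image through it. The input
    # file name is hypothetical, and the grayscale-to-RGB replication is an
    # assumption so the 3-channel SD VAE encoder accepts the input. A trained
    # checkpoint would normally be loaded via SeResVaeModel.from_pretrained.
    from PIL import Image

    config = SeResVaeConfig(height=512, width=512)
    model = SeResVaeModel(config).eval()

    # Replicate the grayscale image to three channels before preprocessing.
    image = Image.open("example_gray.png").convert("L").convert("RGB")

    with torch.no_grad():
        images_rgb = model([image], input_type="pil", output_type="pil")
    images_rgb[0].save("example_rgb.png")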