Spaces:
Sleeping
Sleeping
| import torch | |
| from diffusers.models import AutoencoderKL | |
| from torch import nn | |
| from torchvision.transforms import ToPILImage | |
| from torch import Tensor | |
class SynchronizedGroupNorm(nn.Module):
    """GroupNorm whose statistics are shared across the views of one sample.

    The leading axis of the input is interpreted as ``batch * num_views`` in
    batch-major order (all views of sample 0, then all views of sample 1, ...).
    For every sample and channel group, a single mean/variance is computed
    jointly over all views, the channels of that group and every trailing
    position, instead of one mean/variance per view as ``nn.GroupNorm`` does.

    Args:
        original_group_norm: The ``nn.GroupNorm`` whose configuration
            (``num_groups``, ``num_channels``, ``eps``) and affine parameters
            are taken over.
        num_views: Number of consecutive batch entries that form one sample.

    Raises:
        ValueError: If ``num_views`` is smaller than 1.
    """

    def __init__(self, original_group_norm: nn.GroupNorm, num_views: int = 6):
        super().__init__()
        if num_views < 1:
            raise ValueError(f"num_views must be >= 1, got {num_views}.")
        self.num_views = num_views

        # Take over the grouping configuration of the wrapped layer.
        self.num_groups = original_group_norm.num_groups
        self.num_channels = original_group_norm.num_channels
        self.eps = original_group_norm.eps
        self.group_size = self.num_channels // self.num_groups

        if original_group_norm.weight is None:
            # The source layer was built with affine=False: nothing to copy.
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            # Stored per group as (num_groups, group_size). clone() gives the
            # new parameters their own storage, so later in-place updates do
            # not leak into the original layer (detach().view() alone would
            # still alias it).
            self.weight = nn.Parameter(
                original_group_norm.weight.detach().clone().view(self.num_groups, self.group_size)
            )
            self.bias = nn.Parameter(
                original_group_norm.bias.detach().clone().view(self.num_groups, self.group_size)
            )

    def forward(self, x: torch.Tensor):
        """Normalize ``x`` with statistics pooled across views.

        Args:
            x: Tensor of shape ``(B * num_views, C, D)`` or
                ``(B * num_views, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 3D/4D, if its first dimension is not a
                multiple of ``num_views``, or if its channel count does not
                match the wrapped layer.
        """
        if x.dim() not in (3, 4):
            raise ValueError(f"Unsupported input shape: {x.shape}, expected 3D (B, C, D) or 4D (B, C, H, W).")

        flat_batch, channels = x.shape[:2]
        if flat_batch % self.num_views != 0:
            raise ValueError(
                f"Leading dimension {flat_batch} is not a multiple of num_views={self.num_views}."
            )
        if channels != self.num_channels:
            raise ValueError(f"Expected {self.num_channels} channels, got {channels}.")

        batch = flat_batch // self.num_views
        original_shape = x.shape
        # Collapse every trailing dimension so 3D and 4D inputs share one path.
        positions = x.shape[2:].numel()
        grouped = x.reshape(batch, self.num_views, self.num_groups, self.group_size, positions)

        # One mean/variance per (sample, group): reduce over views, the
        # channels inside the group and all positions.
        reduce_dims = (1, 3, 4)
        mean = grouped.mean(dim=reduce_dims, keepdim=True)
        var = grouped.var(dim=reduce_dims, keepdim=True, unbiased=False)
        normalized = (grouped - mean) / torch.sqrt(var + self.eps)

        if self.weight is not None:
            # Per-channel affine transform, broadcast over batch, views and positions.
            scale = self.weight.view(1, 1, self.num_groups, self.group_size, 1)
            shift = self.bias.view(1, 1, self.num_groups, self.group_size, 1)
            normalized = normalized * scale + shift

        return normalized.reshape(original_shape)
class CubemapVAE(AutoencoderKL):
    """AutoencoderKL wrapper that processes cubemaps as ``num_views`` faces.

    The encoder, decoder and (post-)quant convolutions are taken directly from
    ``pretrained_vae``; every ``nn.GroupNorm`` reachable from this model is
    then replaced by a ``SynchronizedGroupNorm`` so normalization statistics
    are shared across the faces of one cubemap.

    NOTE: the submodules are reused, not copied, so the GroupNorm replacement
    also modifies ``pretrained_vae`` itself.

    Args:
        pretrained_vae: Source VAE providing ``config.latent_channels``,
            ``encoder``, ``decoder``, ``quant_conv`` and ``post_quant_conv``.
        num_views: Number of faces per cubemap sample.
        in_channels: Number of image channels (also used as ``out_channels``).
        image_size: Currently unused; kept so existing callers keep working.
    """

    def __init__(self, pretrained_vae, num_views=6, in_channels=3, image_size=512):
        # The base-class constructor builds its own encoder/decoder from this
        # fixed architecture; they are overwritten with the pretrained ones
        # right below.
        super().__init__(
            act_fn="silu",
            block_out_channels=[128, 256, 512, 512],
            down_block_types=[
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
            ],
            up_block_types=[
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
            ],
            latent_channels=pretrained_vae.config.latent_channels,
            in_channels=in_channels,
            out_channels=in_channels,
        )
        self.num_views = num_views
        self.in_channels = in_channels

        # Reuse the pretrained weights instead of the freshly built modules.
        self.encoder = pretrained_vae.encoder
        self.decoder = pretrained_vae.decoder
        self.quant_conv = pretrained_vae.quant_conv
        self.post_quant_conv = pretrained_vae.post_quant_conv

        # Swap every GroupNorm for its cross-view synchronized counterpart.
        replace_group_norm_with_sgn(self, num_views=num_views)

    def encode(self, images, return_dict: bool = True):
        """Encode a batch of cubemaps.

        Args:
            images: Tensor of shape ``(batch, num_views, channels, H, W)``.
            return_dict: Forwarded to ``AutoencoderKL.encode``.

        Returns:
            Whatever ``AutoencoderKL.encode`` returns for the faces flattened
            to ``(batch * num_views, channels, H, W)`` in batch-major order.
        """
        batch_size, num_views, num_channels, height, width = images.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        images = images.reshape(batch_size * num_views, num_channels, height, width)
        return super().encode(images, return_dict=return_dict)

    def decode(self, latents, return_dict=True, **kwargs):
        """Decode latents, ignoring any extra (e.g. UV) channels.

        Only the first ``config.latent_channels`` channels are passed to the
        regular VAE decoding path; additional channels are dropped.

        Args:
            latents: Tensor of shape ``(batch * num_views, C, h, w)``.
            return_dict: Forwarded to ``AutoencoderKL.decode``.
            **kwargs: Forwarded to ``AutoencoderKL.decode``.
        """
        print("Decoder Recieve Latent Shape:", latents.shape)
        # Use the configured latent width rather than a hard-coded 4, so VAEs
        # with a different number of latent channels are not truncated.
        latent_channels = self.config.latent_channels
        if latents.shape[1] > latent_channels:
            latents = latents[:, :latent_channels, :, :]
        return super().decode(latents, return_dict=return_dict, **kwargs)

    def decode_to_tensor(self, latents):
        """Decode latents and return one image tensor per view.

        Args:
            latents: Tensor of shape ``(batch * num_views, C, h, w)`` in
                batch-major order (the layout produced by ``encode``).

        Returns:
            Tuple of ``num_views`` tensors, each of shape
            ``(batch, channels, H, W)``; entry ``v`` holds view ``v`` of every
            sample.

        Raises:
            ValueError: If the decoded batch size is not a multiple of
                ``num_views``.
        """
        decoded = self.decode(latents).sample  # (batch * num_views, C, H, W)
        total = decoded.shape[0]
        if total % self.num_views != 0:
            raise ValueError(
                f"Decoded batch size {total} is not a multiple of num_views={self.num_views}."
            )
        batch = total // self.num_views
        # The flat axis is batch-major, so separate (batch, view) first and
        # then split along the view axis; slicing the flat axis into chunks of
        # size `batch` would mix different views whenever batch > 1.
        per_view = decoded.reshape(batch, self.num_views, *decoded.shape[1:])
        return per_view.unbind(dim=1)

    def decode_to_pil_images(self, latents: Tensor):
        """Decode latents and return the views of the first sample as PIL images."""
        images = self.decode_to_tensor(latents)
        to_pil = ToPILImage()
        # img[0] selects the first sample of each per-view batch.
        return [to_pil(img[0].cpu().detach()) for img in images]
def replace_group_norm_with_sgn(model, num_views):
    """Replace every nn.GroupNorm found in ``model`` with a SynchronizedGroupNorm.

    The replacement happens in place on ``model``; nothing is returned.
    """
    # Snapshot the qualified names first so the module tree is not modified
    # while it is being traversed.
    group_norm_names = [
        qualified_name
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, nn.GroupNorm)
    ]

    for qualified_name in group_norm_names:
        owner, attribute = get_parent_module(model, qualified_name)
        original_norm = getattr(owner, attribute)
        setattr(owner, attribute, SynchronizedGroupNorm(original_norm, num_views))
def get_parent_module(model, module_name):
    """Resolve the owner of a dotted ``module_name`` inside ``model``.

    Returns a ``(parent, attribute_name)`` pair, where ``parent`` is the object
    reached by following every path component except the last one, and
    ``attribute_name`` is that last component.
    """
    parent_path, separator, attr_name = module_name.rpartition(".")
    parent_module = model
    if separator:
        # Walk down every component in front of the final dot.
        for component in parent_path.split("."):
            parent_module = getattr(parent_module, component)
    return parent_module, attr_name
def flatten_face_names(face_names):
    """Flatten one level of nesting in ``face_names``.

    Strings are kept as they are; the items of lists and tuples are appended
    in order. Any other entry type raises ``ValueError``.
    """
    flattened = []
    for entry in face_names:
        if isinstance(entry, (list, tuple)):
            # A nested sequence: splice its items into the result.
            flattened += entry
        elif isinstance(entry, str):
            flattened.append(entry)
        else:
            raise ValueError(f"Unexpected type in face_names: {type(entry)}")
    return flattened