# SpatialDiffusion/scripts/cubemap_vae.py
import torch
from diffusers.models import AutoencoderKL
from torch import nn
from torchvision.transforms import ToPILImage
from torch import Tensor
class SynchronizedGroupNorm(nn.Module):
    def __init__(self, original_group_norm: nn.GroupNorm, num_views: int = 6):
        super().__init__()
        self.num_views = num_views
        # Inherit the grouping parameters of the original GroupNorm
        self.num_groups = original_group_norm.num_groups
        self.num_channels = original_group_norm.num_channels
        self.eps = original_group_norm.eps
        # Re-arrange the affine parameters per channel group
        self.group_size = self.num_channels // self.num_groups
        # Keep the original parameters (preserves the per-group affine transform)
        self.weight = nn.Parameter(original_group_norm.weight.detach().view(self.num_groups, self.group_size))
        self.bias = nn.Parameter(original_group_norm.bias.detach().view(self.num_groups, self.group_size))
    def forward(self, x: torch.Tensor):
        """Supports both 3D (B, C, D) and 4D (B, C, H, W) inputs."""
        # print(f"Input shape: {x.shape}")  # Debugging
        # The batch dimension carries all cubemap faces: BxT = B * num_views
        BxT, C = x.shape[:2]
        B = BxT // self.num_views
        # 3D (B, C, D) input
        if x.dim() == 3:
            D = x.shape[2]
            x = x.view(B, self.num_views, self.num_groups, self.group_size, D)
            # GroupNorm statistics shared across all views of a sample
            mean = x.mean(dim=(1, 3, 4), keepdim=True)
            var = x.var(dim=(1, 3, 4), keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + self.eps)
            # Reshape weight and bias so they broadcast over (views, groups, group_size, D)
            weight = self.weight.view(1, self.num_groups, self.group_size, 1)
            bias = self.bias.view(1, self.num_groups, self.group_size, 1)
            x = x * weight + bias
            # Restore the original shape
            return x.view(BxT, C, D)
        # 4D (B, C, H, W) input
        elif x.dim() == 4:
            H, W = x.shape[2:]
            x = x.view(B, self.num_views, self.num_groups, self.group_size, H, W)
            # GroupNorm statistics shared across all views of a sample
            mean = x.mean(dim=(1, 3, 4, 5), keepdim=True)
            var = x.var(dim=(1, 3, 4, 5), keepdim=True, unbiased=False)
            x = (x - mean) / torch.sqrt(var + self.eps)
            # Reshape weight and bias so they broadcast over (views, groups, group_size, H, W)
            weight = self.weight.view(1, self.num_groups, self.group_size, 1, 1)
            bias = self.bias.view(1, self.num_groups, self.group_size, 1, 1)
            x = x * weight + bias
            # Restore the original shape
            return x.view(BxT, C, H, W)
        else:
            raise ValueError(f"Unsupported input shape: {x.shape}, expected 3D (B, C, D) or 4D (B, C, H, W).")
class CubemapVAE(AutoencoderKL):
    def __init__(self, pretrained_vae, num_views=6, in_channels=3, image_size=512):
        super().__init__(  # inherits from AutoencoderKL
            act_fn="silu",
            block_out_channels=[128, 256, 512, 512],
            down_block_types=[
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D",
                "DownEncoderBlock2D"
            ],
            up_block_types=[
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D",
                "UpDecoderBlock2D"
            ],
            latent_channels=pretrained_vae.config.latent_channels,
            in_channels=in_channels,
            out_channels=in_channels,
            sample_size=image_size
        )
        self.num_views = num_views
        self.in_channels = in_channels
        # --- Replace the key modules to adapt to cubemaps ---
        # The stock AutoencoderKL encoder is not flexible enough, so the encoder is overridden directly
        # self.encoder = CubemapEncoder(pretrained_encoder=pretrained_vae.encoder, num_views=num_views, in_channels=in_channels)
        # self.decoder = CubemapDecoder(pretrained_decoder=pretrained_vae.decoder, num_views=num_views, out_channels=in_channels, in_channels=4)
        self.encoder = pretrained_vae.encoder
        self.decoder = pretrained_vae.decoder
        self.quant_conv = pretrained_vae.quant_conv
        self.post_quant_conv = pretrained_vae.post_quant_conv
        # Replace the original GroupNorm layers with the view-synchronized GroupNorm
        replace_group_norm_with_sgn(self, num_views=num_views)
    def encode(self, images, return_dict: bool = True):
        # Flatten the view dimension into the batch so the pretrained encoder sees (B * num_views, C, H, W)
        batch_size, num_views, num_channels, height, width = images.shape
        images = images.view(batch_size * num_views, num_channels, height, width)
        return super().encode(images, return_dict=return_dict)
    def decode(self, latents, return_dict=True, **kwargs):
        """
        Custom VAE decoding:
        - Strip the UV channels (keep only the first 4 latent channels)
        - Run the original VAE decoding pipeline
        """
        print("Decoder Received Latent Shape:", latents.shape)
        # Trim any extra channels so only the first 4 latent channels are decoded
        if latents.shape[1] > 4:
            latents = latents[:, :4, :, :]  # keep the first 4 channels, drop the UV channels
        return super().decode(latents, return_dict=return_dict, **kwargs)
    def decode_to_tensor(self, latents):
        decoded = self.decode(latents).sample  # (B*6, 3, H, W)
        B = latents.shape[0] // 6
        # Regroup the flattened faces and split along the view dimension
        images = decoded.view(B, 6, *decoded.shape[1:]).unbind(dim=1)
        return images  # tuple of 6 tensors, one per cubemap face, each (B, 3, H, W)
    def decode_to_pil_images(self, latents: Tensor):
        images = self.decode_to_tensor(latents)  # the 6 face tensors
        to_pil = ToPILImage()
        # Map from the VAE's [-1, 1] output range to [0, 1] before converting the first sample of each face to PIL
        return [to_pil((img[0] / 2 + 0.5).clamp(0, 1).cpu().detach()) for img in images]
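# A hedged round-trip sketch for the class above (shapes are assumptions that hold for a
# Stable-Diffusion-style VAE with 4 latent channels and 8x downsampling); a runnable demo
# sits at the bottom of this file:
#
#   faces = torch.rand(1, 6, 3, 512, 512) * 2 - 1          # one cubemap, values in [-1, 1]
#   latents = vae.encode(faces).latent_dist.sample()        # (6, 4, 64, 64)
#   pil_faces = vae.decode_to_pil_images(latents)            # list of 6 PIL images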
def replace_group_norm_with_sgn(model, num_views):
    """Walk the model, find every nn.GroupNorm and replace it with SynchronizedGroupNorm."""
    replacements = []  # collect the names of modules to replace first, then mutate
    for name, module in model.named_modules():
        if isinstance(module, nn.GroupNorm):
            replacements.append(name)
    for name in replacements:
        parent_module, attr_name = get_parent_module(model, name)
        setattr(parent_module, attr_name, SynchronizedGroupNorm(getattr(parent_module, attr_name), num_views))
def get_parent_module(model, module_name):
    """Return the parent module that `module_name` lives in, together with the child's attribute name."""
    names = module_name.split(".")
    parent_module = model
    for name in names[:-1]:  # walk down to the second-to-last level
        parent_module = getattr(parent_module, name)
    return parent_module, names[-1]
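# Illustrative example (the module path below is hypothetical but typical of diffusers VAEs):
#
#   parent, attr = get_parent_module(vae, "encoder.down_blocks.0.resnets.0.norm1")
#   setattr(parent, attr, SynchronizedGroupNorm(getattr(parent, attr), num_views=6))
#
# This is exactly the pattern replace_group_norm_with_sgn applies to every GroupNorm it finds.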
def flatten_face_names(face_names):
    flat_face_names = []
    for item in face_names:
        if isinstance(item, str):  # a plain string, keep as-is
            flat_face_names.append(item)
        elif isinstance(item, (list, tuple)):  # a nested list/tuple, expand its strings
            flat_face_names.extend(item)
        else:
            raise ValueError(f"Unexpected type in face_names: {type(item)}")
    return flat_face_names
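if __name__ == "__main__":
    # Minimal smoke test: an illustrative sketch, not part of the original script.
    # It assumes a Stable-Diffusion-style pretrained VAE; "stabilityai/sd-vae-ft-mse"
    # is used purely as an example checkpoint id.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pretrained_vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
    vae = CubemapVAE(pretrained_vae, num_views=6, in_channels=3, image_size=512).to(device)
    vae.eval()
    # One cubemap: batch of 1, 6 faces, 3 channels, 512x512 pixels, values in [-1, 1].
    faces = torch.rand(1, 6, 3, 512, 512, device=device) * 2 - 1
    with torch.no_grad():
        latents = vae.encode(faces).latent_dist.sample()  # expected (6, 4, 64, 64) for this config
        pil_faces = vae.decode_to_pil_images(latents)
    print(f"Decoded {len(pil_faces)} faces, size {pil_faces[0].size}")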