# Source path: controlnet/SyncDreamer/ldm/models/diffusion/sync_dreamer_network.py
# Hosting-page residue: "nekoshadow's picture"
# Commit message: "Fix import issue"
# Commit: 6ad2dc4
import torch
import torch.nn as nn
class Image2DResBlockWithTV(nn.Module):
    """Residual 2D convolution block conditioned on time and viewpoint embeddings.

    The time and viewpoint embeddings are each projected to ``dim`` channels by
    a 1x1 convolution and added to the input. The sum goes through two
    GroupNorm -> SiLU -> 3x3 convolution stages and the result is added back
    to the unmodified input.

    Args:
        dim: Channel count of the input and output feature maps. GroupNorm is
            built with 8 groups, so ``dim`` must be divisible by 8.
        tdim: Channel count of the time embedding.
        vdim: Channel count of the viewpoint embedding.
    """

    def __init__(self, dim: int, tdim: int, vdim: int):
        super().__init__()
        self.time_embed = nn.Conv2d(tdim, dim, 1, 1)
        self.view_embed = nn.Conv2d(vdim, dim, 1, 1)
        # The layer order inside the Sequential fixes the state_dict keys
        # (conv.0 ... conv.5); do not reorder.
        self.conv = nn.Sequential(
            nn.GroupNorm(8, dim),
            nn.SiLU(True),
            nn.Conv2d(dim, dim, 3, 1, 1),
            nn.GroupNorm(8, dim),
            nn.SiLU(True),
            nn.Conv2d(dim, dim, 3, 1, 1),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Apply the conditioned residual block.

        Args:
            x: Feature map with ``dim`` channels.
            t: Time embedding with ``tdim`` channels; its projection must
                broadcast against ``x``.
            v: Viewpoint embedding with ``vdim`` channels; its projection must
                broadcast against ``x``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        conditioned = x + self.time_embed(t) + self.view_embed(v)
        return x + self.conv(conditioned)
class NoisyTargetViewEncoder(nn.Module):
    """Encode a 4-channel 2D input into features, conditioned on time and viewpoint.

    The input is lifted to ``run_dim`` channels, passed through three
    ``Image2DResBlockWithTV`` blocks that all receive the same time and
    viewpoint embeddings, and projected to ``output_dim`` channels.

    Args:
        time_embed_dim: Length of the time embedding vector.
        viewpoint_dim: Length of the viewpoint embedding vector.
        run_dim: Internal channel count. GroupNorm is built with 8 groups, so
            it must be divisible by 8.
        output_dim: Channel count of the returned feature map.
    """

    def __init__(self, time_embed_dim, viewpoint_dim, run_dim=16, output_dim=8):
        super().__init__()
        self.init_conv = nn.Conv2d(4, run_dim, 3, 1, 1)
        self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
        self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
        self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
        self.final_out = nn.Sequential(
            nn.GroupNorm(8, run_dim),
            nn.SiLU(True),
            nn.Conv2d(run_dim, output_dim, 3, 1, 1),
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Run the encoder.

        Args:
            x: Input with 4 channels (the first convolution expects 4).
            t: 2-D time embedding of shape ``(B, DT)``.
            v: 2-D viewpoint embedding of shape ``(B, DV)``.

        Returns:
            Feature map with ``output_dim`` channels.
        """
        # Unpacking exactly two dims rejects embeddings that are not 2-D.
        batch, t_channels = t.shape
        t = t.view(batch, t_channels, 1, 1)
        batch, v_channels = v.shape
        v = v.view(batch, v_channels, 1, 1)

        x = self.init_conv(x)
        for block in (self.out_conv0, self.out_conv1, self.out_conv2):
            x = block(x, t, v)
        return self.final_out(x)
class SpatialUpTimeBlock(nn.Module):
    """Time-conditioned 3D upsampling block.

    The time embedding is projected to ``x_in_dim`` channels by a 1x1x1
    convolution and added to the input; the sum is normalised (GroupNorm,
    8 groups), passed through SiLU and upsampled by a stride-2 transposed
    3D convolution. With kernel 3, padding 1 and output padding 1, each
    spatial size is doubled.

    Args:
        x_in_dim: Channel count of the input volume (divisible by 8).
        t_in_dim: Channel count of the time embedding.
        out_dim: Channel count of the output volume.
    """

    def __init__(self, x_in_dim: int, t_in_dim: int, out_dim: int):
        super().__init__()
        self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1)
        self.norm = nn.GroupNorm(8, x_in_dim)
        self.silu = nn.SiLU(True)
        self.conv = nn.ConvTranspose3d(
            x_in_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Add the projected time embedding to ``x``, then normalise, activate and upsample."""
        x = x + self.t_conv(t)
        return self.conv(self.silu(self.norm(x)))
class SpatialTimeBlock(nn.Module):
    """Time-conditioned 3D convolution block.

    The time embedding is projected to ``x_in_dim`` channels by a 1x1x1
    convolution and added to the input; the sum is normalised, passed through
    SiLU and convolved with a 3x3x3 kernel (padding 1) at the given stride.

    Args:
        x_in_dim: Channel count of the input volume (divisible by 8).
        t_in_dim: Channel count of the time embedding.
        out_dim: Channel count of the output volume.
        stride: Stride of the 3x3x3 convolution.
    """

    def __init__(self, x_in_dim: int, t_in_dim: int, out_dim: int, stride: int):
        super().__init__()
        self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1)
        # Despite its name this is a GroupNorm; the attribute name is kept
        # because it is part of the state_dict keys.
        self.bn = nn.GroupNorm(8, x_in_dim)
        self.silu = nn.SiLU(True)
        self.conv = nn.Conv3d(x_in_dim, out_dim, 3, stride=stride, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Add the projected time embedding to ``x``, then normalise, activate and convolve."""
        x = x + self.t_conv(t)
        return self.conv(self.silu(self.bn(x)))
class SpatialTime3DNet(nn.Module):
    """Time-conditioned 3D encoder/decoder with additive skip connections.

    The encoder alternates stride-1 and stride-2 ``SpatialTimeBlock`` stages
    (three stride-2 stages in total); the decoder applies three
    ``SpatialUpTimeBlock`` stages, adding the encoder feature of the matching
    level after each one.

    Args:
        time_dim: Length of the time embedding vector.
        input_dim: Channel count of the input volume.
        dims: Channel counts ``(d0, d1, d2, d3)`` of the four levels.
    """

    def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
        super().__init__()
        d0, d1, d2, d3 = dims
        dt = time_dim

        self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1)
        self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1)

        self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2)
        self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1)
        self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1)

        self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2)
        self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1)
        self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1)

        self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2)
        self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1)
        self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1)

        self.conv7 = SpatialUpTimeBlock(d3, dt, d2)
        self.conv8 = SpatialUpTimeBlock(d2, dt, d1)
        self.conv9 = SpatialUpTimeBlock(d1, dt, d0)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: Input volume with ``input_dim`` channels.
                NOTE(review): the skip additions need the upsampled and
                encoder features to have equal shapes, which presumably
                requires spatial sizes divisible by 8 -- confirm with callers.
            t: 2-D time embedding of shape ``(B, C)``.

        Returns:
            Volume with ``dims[0]`` channels.
        """
        # Unpacking exactly two dims rejects embeddings that are not 2-D.
        batch, channels = t.shape
        t = t.view(batch, channels, 1, 1, 1)

        x = self.init_conv(x)
        conv0 = self.conv0(x, t)

        x = self.conv1(conv0, t)
        x = self.conv2_0(x, t)
        conv2 = self.conv2_1(x, t)

        x = self.conv3(conv2, t)
        x = self.conv4_0(x, t)
        conv4 = self.conv4_1(x, t)

        x = self.conv5(conv4, t)
        x = self.conv6_0(x, t)
        x = self.conv6_1(x, t)

        # Decoder: upsample and add the encoder feature of the same level.
        x = conv4 + self.conv7(x, t)
        x = conv2 + self.conv8(x, t)
        x = conv0 + self.conv9(x, t)
        return x
class FrustumTVBlock(nn.Module):
    """3D convolution block conditioned on time and viewpoint embeddings.

    Both embeddings are projected to ``x_dim`` channels by 1x1x1 convolutions
    and added to the input; the sum is normalised, passed through SiLU and
    convolved with a 3x3x3 kernel (padding 1) at the given stride.

    Args:
        x_dim: Channel count of the input volume (divisible by 8).
        t_dim: Channel count of the time embedding.
        v_dim: Channel count of the viewpoint embedding.
        out_dim: Channel count of the output volume.
        stride: Stride of the 3x3x3 convolution.
    """

    def __init__(self, x_dim: int, t_dim: int, v_dim: int, out_dim: int, stride: int):
        super().__init__()
        self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1)
        self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1)
        # Despite its name this is a GroupNorm; the attribute name is kept
        # because it is part of the state_dict keys.
        self.bn = nn.GroupNorm(8, x_dim)
        self.silu = nn.SiLU(True)
        self.conv = nn.Conv3d(x_dim, out_dim, 3, stride=stride, padding=1)

    def forward(self, x: torch.Tensor, t: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Add both projected embeddings to ``x``, then normalise, activate and convolve."""
        x = x + self.t_conv(t) + self.v_conv(v)
        return self.conv(self.silu(self.bn(x)))
class FrustumTVUpBlock(nn.Module):
    """3D upsampling block conditioned on time and viewpoint embeddings.

    Both embeddings are projected to ``x_dim`` channels by 1x1x1 convolutions
    and added to the input; the sum is normalised, passed through SiLU and
    upsampled by a stride-2 transposed 3D convolution. With kernel 3,
    padding 1 and output padding 1, each spatial size is doubled.

    Args:
        x_dim: Channel count of the input volume (divisible by 8).
        t_dim: Channel count of the time embedding.
        v_dim: Channel count of the viewpoint embedding.
        out_dim: Channel count of the output volume.
    """

    def __init__(self, x_dim: int, t_dim: int, v_dim: int, out_dim: int):
        super().__init__()
        self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1)
        self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1)
        self.norm = nn.GroupNorm(8, x_dim)
        self.silu = nn.SiLU(True)
        self.conv = nn.ConvTranspose3d(
            x_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        """Add both projected embeddings to ``x``, then normalise, activate and upsample."""
        x = x + self.t_conv(t) + self.v_conv(v)
        return self.conv(self.silu(self.norm(x)))
class FrustumTV3DNet(nn.Module):
    """Time- and viewpoint-conditioned 3D encoder/decoder returning multi-scale features.

    Three encoder levels each apply a stride-2 ``FrustumTVBlock`` followed by
    a stride-1 one; three ``FrustumTVUpBlock`` stages then upsample, adding
    the encoder feature of the matching level after each stage.

    Args:
        in_dim: Channel count of the input volume.
        t_dim: Length of the time embedding vector.
        v_dim: Length of the viewpoint embedding vector.
        dims: Channel counts of the four levels, finest first.
    """

    def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)):
        super().__init__()
        self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1)

        self.conv1 = FrustumTVBlock(dims[0], t_dim, v_dim, dims[1], 2)
        self.conv2 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[1], 1)

        self.conv3 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[2], 2)
        self.conv4 = FrustumTVBlock(dims[2], t_dim, v_dim, dims[2], 1)

        self.conv5 = FrustumTVBlock(dims[2], t_dim, v_dim, dims[3], 2)
        self.conv6 = FrustumTVBlock(dims[3], t_dim, v_dim, dims[3], 1)

        self.up0 = FrustumTVUpBlock(dims[3], t_dim, v_dim, dims[2])
        self.up1 = FrustumTVUpBlock(dims[2], t_dim, v_dim, dims[1])
        self.up2 = FrustumTVUpBlock(dims[1], t_dim, v_dim, dims[0])

    def forward(self, x: torch.Tensor, t: torch.Tensor, v: torch.Tensor):
        """Run the network and return the features of every level.

        Args:
            x: 5-D input volume with ``in_dim`` channels; the last dimension
                is the width ``w`` used for the output keys.
            t: 2-D time embedding of shape ``(B, DT)``.
            v: 2-D viewpoint embedding of shape ``(B, DV)``.

        Returns:
            A dict mapping ``w``, ``w // 2``, ``w // 4`` and ``w // 8`` to the
            feature volume of the corresponding level: the three finer levels
            are decoder outputs with the skip added, the coarsest one is the
            encoder bottleneck.
            NOTE(review): the keys are only distinct, and the skip additions
            presumably only line up, when ``w`` is a multiple of 8 -- confirm
            with callers.
        """
        # Unpacking exactly two dims rejects embeddings that are not 2-D.
        batch, t_channels = t.shape
        t = t.view(batch, t_channels, 1, 1, 1)
        batch, v_channels = v.shape
        v = v.view(batch, v_channels, 1, 1, 1)

        # Only the width is needed; the five-way unpack keeps the 5-D check.
        _, _, _, _, width = x.shape

        x0 = self.conv0(x)
        x1 = self.conv2(self.conv1(x0, t, v), t, v)
        x2 = self.conv4(self.conv3(x1, t, v), t, v)
        x3 = self.conv6(self.conv5(x2, t, v), t, v)

        # Decoder: upsample and add the encoder feature of the same level.
        x2 = self.up0(x3, t, v) + x2
        x1 = self.up1(x2, t, v) + x1
        x0 = self.up2(x1, t, v) + x0
        return {width: x0, width // 2: x1, width // 4: x2, width // 8: x3}