import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class FeedForward(nn.Module):
    '''Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.'''

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Transformer_layer(nn.Module):
    def __init__(self, dim_feat=384, dim=1024, hidden_dim=1024, heads=16):
        '''
        dim: width between the attention and MLP blocks; also the layer's input and output dim
        hidden_dim: total dim of the q/k/v projections (split across heads)
        dim_feat: dim of the conditioning feature
        '''
        super().__init__()
        dim_head = hidden_dim // heads
        self.heads = heads
        self.scale = dim_head ** -0.5  # 1 / sqrt(dim_head)
        self.norm = nn.LayerNorm(dim)
        self.ffn = FeedForward(
            dim=dim,
            hidden_dim=4 * dim,
            dropout=0.
        )
        # cross-attention: queries from the tokens, keys/values from the conditioning feature
        self.to_cross_q = nn.Linear(dim, hidden_dim, bias=False)
        self.to_cross_kv = nn.Linear(dim_feat, hidden_dim * 2, bias=False)
        self.cross_attend = nn.Softmax(dim=-1)
        # self-attention: queries, keys, and values all come from the tokens
        self.to_self_qkv = nn.Linear(dim, hidden_dim * 3, bias=False)
        self.self_attend = nn.Softmax(dim=-1)

    def forward_cross_attn(self, x, feature):
        x = self.norm(x)
        q = self.to_cross_q(x)
        k, v = self.to_cross_kv(feature).chunk(2, dim=-1)
        qkv = (q, k, v)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.cross_attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return out

    def forward_self_attn(self, x):
        x = self.norm(x)
        qkv = self.to_self_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.self_attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return out

    def forward(self, x, feature):
        '''
        x: [B, N, dim]
        feature: [B, M, dim_feat] (M need not equal N; only cross-attention reads it)
        '''
        # pre-norm residual blocks: cross-attention, then self-attention, then feed-forward
        cross_token = self.forward_cross_attn(x, feature)
        cross_token = cross_token + x
        self_token = self.forward_self_attn(cross_token)
        self_token = self_token + cross_token
        out = self.ffn(self_token)
        out = out + self_token
        return out

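
# Illustrative shape check for Transformer_layer (a sketch; the batch size and
# token counts below are assumed for the example, not taken from this repo):
#
#   layer = Transformer_layer(dim_feat=384, dim=1024, hidden_dim=1024, heads=16)
#   x = torch.randn(2, 64, 1024)        # [B, N, dim] query tokens
#   feature = torch.randn(2, 256, 384)  # [B, M, dim_feat] condition tokens
#   out = layer(x, feature)             # [B, N, dim], same shape as x
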

class Triplane_Transformer(nn.Module):
    def __init__(self, emb_dim=1024, emb_num=1024, num_layers=16,
                 triplane_dim=80, triplane_scale=7.):
        super().__init__()
        # emb_num learnable query tokens; emb_num must be a perfect square so the
        # tokens can be folded back into a square feature map in forward()
        self.learnable_embedding = nn.Parameter(torch.randn(1, emb_num, emb_dim))
        self.layers = nn.ModuleList([])
        for _ in range(num_layers):
            self.layers.append(
                Transformer_layer(
                    dim_feat=384,
                    dim=emb_dim,
                    hidden_dim=emb_dim
                )
            )
        self.triplane_dim = triplane_dim
        self.triplane_scale = triplane_scale
        # kernel_size=4, stride=2, padding=1 doubles the spatial size:
        # out = (in - 1) * 2 - 2 * 1 + 4 = 2 * in
        self.to_triplane = nn.ConvTranspose2d(
            in_channels=emb_dim,
            out_channels=3 * triplane_dim,
            kernel_size=4,
            padding=1,
            stride=2
        )
        self.norm = nn.LayerNorm(emb_dim)

    def sample_feat(self, feat_maps, pts):
        '''
        feat_maps: [B, 3, C, H, W] triplane features, ordered (xy, yz, xz)
        pts: [B, K, 3] query points in [-triplane_scale / 2, triplane_scale / 2]^3
        '''
        # normalize points into the [-1, 1] range expected by F.grid_sample
        pts = pts / (self.triplane_scale / 2)
        pts_xy = pts[..., [0, 1]]
        pts_yz = pts[..., [1, 2]]
        pts_xz = pts[..., [0, 2]]
        feat_xy = feat_maps[:, 0, :, :, :]
        feat_yz = feat_maps[:, 1, :, :, :]
        feat_xz = feat_maps[:, 2, :, :, :]
        sampled_feat_xy = F.grid_sample(
            feat_xy, pts_xy.unsqueeze(1), mode='bilinear', align_corners=True
        )
        sampled_feat_yz = F.grid_sample(
            feat_yz, pts_yz.unsqueeze(1), mode='bilinear', align_corners=True
        )
        sampled_feat_xz = F.grid_sample(
            feat_xz, pts_xz.unsqueeze(1), mode='bilinear', align_corners=True
        )
        # concatenate the three per-plane samples along the channel dim
        sampled_feat = torch.cat([sampled_feat_xy, sampled_feat_yz, sampled_feat_xz], dim=1).squeeze(-2)  # [B, 3C, K]
        sampled_feat = sampled_feat.permute(0, 2, 1)  # [B, K, 3C]
        return sampled_feat
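
    # Shape trace through sample_feat, for K query points (illustrative):
    #   pts [B, K, 3] -> pts_xy [B, K, 2] -> unsqueeze(1) -> grid [B, 1, K, 2]
    #   F.grid_sample(feat [B, C, H, W], grid) -> [B, C, 1, K] per plane
    #   cat over the 3 planes -> [B, 3C, 1, K] -> squeeze(-2) -> [B, 3C, K]
    #   permute(0, 2, 1) -> [B, K, 3C]
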
    def forward(self, feature, pts):
        '''
        feature: [B, M, dim_feat] conditioning tokens
        pts: [B, K, 3] query points
        '''
        batch_size = feature.shape[0]
        embedding = self.learnable_embedding.repeat(batch_size, 1, 1)
        x = embedding
        for layer in self.layers:
            x = layer(x, feature)
        x = self.norm(x)
        # x: [B, emb_num, emb_dim]; fold the tokens back into a square grid
        batch_size, pwph, feat_dim = x.shape
        ph = int(pwph ** 0.5)
        pw = int(pwph ** 0.5)
        triplane_feat = x.reshape(batch_size, ph, pw, feat_dim).permute(0, 3, 1, 2)
        triplane_feat = self.to_triplane(triplane_feat)  # [B, 3 * triplane_dim, 2 * ph, 2 * pw]
        triplane_feat = triplane_feat.reshape(triplane_feat.shape[0], 3, self.triplane_dim, triplane_feat.shape[-2], triplane_feat.shape[-1])
        pts_feat = self.sample_feat(triplane_feat, pts)
        return pts_feat
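

if __name__ == '__main__':
    # Minimal smoke test: a sketch under assumed shapes (B=2, M=256 condition
    # tokens, K=4096 query points), with num_layers lowered to keep it cheap.
    model = Triplane_Transformer(emb_dim=1024, emb_num=1024, num_layers=2,
                                 triplane_dim=80, triplane_scale=7.)
    feature = torch.randn(2, 256, 384)         # [B, M, dim_feat]
    pts = (torch.rand(2, 4096, 3) - 0.5) * 7.  # [B, K, 3] in [-3.5, 3.5]^3
    pts_feat = model(feature, pts)
    print(pts_feat.shape)                      # torch.Size([2, 4096, 240]), i.e. [B, K, 3 * triplane_dim]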