import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers with an identity skip connection."""

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        # Add the input back before the final activation.
        return F.relu(x + self.block(x))


class DepthSTAR(nn.Module):
    """Encoder-decoder network with an optional transformer bottleneck;
    outputs a single-channel map bounded to [0, 1]."""

    def __init__(
        self,
        use_residual_blocks=True,
        use_transformer=True,
        transformer_layers=8,
        transformer_heads=8,
        embed_dim=512,
    ):
        super().__init__()
        self.use_residual_blocks = use_residual_blocks
        self.use_transformer = use_transformer

        # Convolutional encoder: two stride-2 convolutions downsample the
        # input by 4x overall while widening channels to embed_dim.
        encoder_layers = [
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        if use_residual_blocks:
            encoder_layers.append(ResidualBlock(128))
        encoder_layers += [
            nn.Conv2d(128, embed_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
        ]
        if use_residual_blocks:
            encoder_layers.append(ResidualBlock(embed_dim))

        self.encoder = nn.Sequential(*encoder_layers)

        # Optional transformer bottleneck applied to the flattened feature map.
        if use_transformer:
            self.bottleneck = nn.TransformerEncoder(
                nn.TransformerEncoderLayer(
                    d_model=embed_dim,
                    nhead=transformer_heads,
                    dim_feedforward=embed_dim * 4,
                    batch_first=True,
                ),
                num_layers=transformer_layers,
            )

        # Transposed-conv decoder: two stride-2 stages restore the input
        # resolution; the final Sigmoid bounds the output to [0, 1].
        decoder_layers = [
            nn.ConvTranspose2d(embed_dim, 128, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        ]
        if use_residual_blocks:
            decoder_layers.append(ResidualBlock(128))
        decoder_layers += [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        B = x.size(0)
        feat = self.encoder(x)
        if self.use_transformer:
            # Flatten (B, C, H, W) -> (B, H*W, C) so each spatial position
            # becomes a token, run the bottleneck, then restore the map.
            C, H, W = feat.shape[1:]
            tokens = feat.flatten(2).transpose(1, 2)
            tokens = self.bottleneck(tokens)
            feat = tokens.transpose(1, 2).reshape(B, C, H, W)
        return self.decoder(feat)
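

# A minimal usage sketch, not from the original source: it assumes 3-channel
# inputs whose height and width are divisible by 4, since the encoder
# downsamples by 4x and the decoder upsamples by the same factor. The shapes
# below are illustrative.
if __name__ == "__main__":
    model = DepthSTAR(use_residual_blocks=True, use_transformer=True)
    dummy = torch.randn(2, 3, 64, 64)  # (B, C, H, W), H and W divisible by 4
    out = model(dummy)
    print(out.shape)  # expected: torch.Size([2, 1, 64, 64])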