from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F


class FOVNetwork(nn.Module):
    """Field of View estimation network."""

    def __init__(
        self,
        num_features: int,
        fov_encoder: Optional[nn.Module] = None,
    ):
        """Initialize the Field of View estimation block.

        Args:
        ----
            num_features: Number of features used.
            fov_encoder: Optional encoder to provide additional network capacity.

        """
        super().__init__()

        # Convolutional head that progressively downsamples the feature map
        # and regresses a single field-of-view value.
        fov_head0 = [
            nn.Conv2d(
                num_features, num_features // 2, kernel_size=3, stride=2, padding=1
            ),
            nn.ReLU(True),
        ]
        fov_head = [
            nn.Conv2d(
                num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=1
            ),
            nn.ReLU(True),
            nn.Conv2d(
                num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=1
            ),
            nn.ReLU(True),
            nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=0),
        ]
        if fov_encoder is not None:
            # With an extra encoder, project its output tokens to half the
            # feature width and keep the first downsampling stage separate so
            # the two branches can be fused in forward().
            self.encoder = nn.Sequential(
                fov_encoder, nn.Linear(fov_encoder.embed_dim, num_features // 2)
            )
            self.downsample = nn.Sequential(*fov_head0)
        else:
            fov_head = fov_head0 + fov_head
        self.head = nn.Sequential(*fov_head)

    def forward(self, x: torch.Tensor, lowres_feature: torch.Tensor) -> torch.Tensor:
        """Forward the FOV network.

        Args:
        ----
            x (torch.Tensor): Input image.
            lowres_feature (torch.Tensor): Low-resolution feature map.

        Returns:
        -------
            The field of view tensor.

        """
        if hasattr(self, "encoder"):
            # Downsample the image to quarter resolution and encode it with
            # the optional FOV encoder.
            x = F.interpolate(
                x,
                size=None,
                scale_factor=0.25,
                mode="bilinear",
                align_corners=False,
            )
            # Drop the first (class) token, move channels to the second
            # dimension, and fuse the tokens with the downsampled
            # low-resolution feature map.
            x = self.encoder(x)[:, 1:].permute(0, 2, 1)
            lowres_feature = self.downsample(lowres_feature)
            x = x.reshape_as(lowres_feature) + lowres_feature
        else:
            x = lowres_feature
        return self.head(x)
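

# A minimal usage sketch (not part of the module): the feature width and the
# 48 x 48 low-resolution feature size below are illustrative assumptions, not
# values mandated by this file. With the convolutional head above, 48 x 48 is
# reduced to 1 x 1, so the output is a single FOV value per image. When a
# fov_encoder (e.g. a ViT exposing `embed_dim`) is passed, the image branch is
# also encoded and fused with the low-resolution feature.
if __name__ == "__main__":
    num_features = 256
    net = FOVNetwork(num_features=num_features, fov_encoder=None)

    # The image input is only consumed when a fov_encoder is provided.
    image = torch.randn(1, 3, 192, 192)
    lowres_feature = torch.randn(1, num_features, 48, 48)

    fov = net(image, lowres_feature)
    print(fov.shape)  # torch.Size([1, 1, 1, 1])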