Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import pandas as pd | |
from models.networks.utils import UnormGPS | |
class HybridHead(nn.Module):
    """Classification head followed by a per-cell regression head.

    The backbone output is split along its last axis: the first
    ``final_dim`` entries are classification logits over quadtree cells and
    the remaining entries are read as one 2-vector offset per cell. The
    offset of the selected cell is applied around that cell's center to
    produce a GPS prediction.
    """

    def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
        """Store the head configuration and optionally load the quadtree.

        Args:
            final_dim: number of leading entries of the input treated as
                classification logits.
            quadtree_path: path of a CSV file describing the cells, or
                ``None`` to skip loading (the cell buffers are then not
                registered and ``forward`` cannot be used until
                ``init_quadtree`` is called).
            use_tanh: whether to squash the raw offsets with ``tanh``.
            scale_tanh: multiplier applied to the ``tanh`` output.
        """
        super().__init__()
        self.final_dim = final_dim
        self.use_tanh = use_tanh
        self.scale_tanh = scale_tanh
        # NOTE(review): presumably maps normalized coordinates back to
        # real GPS units -- confirm against models.networks.utils.
        self.unorm = UnormGPS()
        if quadtree_path is not None:
            self.init_quadtree(pd.read_csv(quadtree_path))

    def init_quadtree(self, quadtree):
        """Register the ``cell_center`` and ``cell_size`` buffers.

        Latitude columns are divided by 90 and longitude columns by 180.
        The normalization is done on fresh arrays, so the caller's
        DataFrame is left untouched.

        Args:
            quadtree: DataFrame with the columns ``min_lat``, ``max_lat``,
                ``min_lon`` and ``max_lon``, one row per cell.
        """
        scale = (90.0, 180.0)  # (latitude, longitude) divisors
        upper = torch.tensor(quadtree[["max_lat", "max_lon"]].values / scale)
        lower = torch.tensor(quadtree[["min_lat", "min_lon"]].values / scale)
        self.register_buffer("cell_center", 0.5 * upper + 0.5 * lower)
        self.register_buffer("cell_size", upper - lower)

    def forward(self, x, gt_label):
        """Decode backbone features into logits and a GPS prediction.

        In training mode the offset and cell geometry are selected with
        ``gt_label``; otherwise the argmax of the logits is used, and
        ``gt_label`` is not read.

        Args:
            x: backbone output; assumed to be 2-D, ``(batch, final_dim +
                2 * n_cells)`` -- TODO confirm against the caller.
            gt_label: integer cell index per sample, shape ``(batch,)``.

        Returns:
            dict with ``label`` (logits), ``gps`` (output of ``unorm``),
            ``size`` (``2 / cell_size`` of the selected cell), ``center``
            (center of the selected cell) and ``reg`` (selected offset).
        """
        classification_logits = x[..., : self.final_dim]
        classification = classification_logits.argmax(dim=-1)

        regression = x[..., self.final_dim :]
        if self.use_tanh:
            regression = self.scale_tanh * torch.tanh(regression)
        # One (lat, lon) offset per cell.
        regression = regression.view(regression.shape[0], -1, 2)

        # Ground-truth cell while training, predicted cell otherwise.
        cell = gt_label if self.training else classification

        index = cell.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2)
        regression = torch.gather(regression, 1, index)[:, 0, :]

        cell_size = self.cell_size[cell]
        center = self.cell_center[cell]
        size = 2.0 / cell_size
        # The offset is expressed in units of half a cell extent.
        gps = self.unorm(center + regression * cell_size / 2.0)

        return {
            "label": classification_logits,
            "gps": gps,
            "size": size,
            "center": center,
            "reg": regression,
        }
class HybridHeadCentroid(nn.Module):
    """Classification head followed by a centroid-anchored regression head.

    Same input layout as a logits-plus-per-cell-offsets head, but each
    offset is applied around the cell's ``mean_lat`` / ``mean_lon`` point
    and scaled by the distance from that point to the upper bound (for
    positive offsets) or to the lower bound (otherwise).
    """

    def __init__(self, final_dim, quadtree_path, use_tanh, scale_tanh):
        """Store the head configuration and optionally load the quadtree.

        Args:
            final_dim: number of leading entries of the input treated as
                classification logits.
            quadtree_path: path of a CSV file describing the cells, or
                ``None`` to skip loading (``forward`` then cannot be used
                until ``init_quadtree`` is called).
            use_tanh: whether to squash the raw offsets with ``tanh``.
            scale_tanh: multiplier applied to the ``tanh`` output.
        """
        super().__init__()
        self.final_dim = final_dim
        self.use_tanh = use_tanh
        self.scale_tanh = scale_tanh
        # NOTE(review): presumably maps normalized coordinates back to
        # real GPS units -- confirm against models.networks.utils.
        self.unorm = UnormGPS()
        if quadtree_path is not None:
            self.init_quadtree(pd.read_csv(quadtree_path))

    def init_quadtree(self, quadtree):
        """Build the per-cell lookup tables from the quadtree description.

        Latitude columns are divided by 90 and longitude columns by 180 on
        fresh arrays, so the caller's DataFrame is left untouched. The
        tables are registered as non-persistent buffers: they follow
        ``module.to(...)`` like any buffer but add no ``state_dict`` keys.

        Args:
            quadtree: DataFrame with the columns ``min_lat``, ``max_lat``,
                ``mean_lat``, ``min_lon``, ``max_lon`` and ``mean_lon``,
                one row per cell.
        """
        scale = (90.0, 180.0)  # (latitude, longitude) divisors
        upper = torch.tensor(quadtree[["max_lat", "max_lon"]].values / scale)
        lower = torch.tensor(quadtree[["min_lat", "min_lon"]].values / scale)
        mean = torch.tensor(quadtree[["mean_lat", "mean_lon"]].values / scale)
        self.register_buffer("cell_center", mean, persistent=False)
        self.register_buffer("cell_size_up", upper - mean, persistent=False)
        self.register_buffer("cell_size_down", mean - lower, persistent=False)

    def forward(self, x, gt_label):
        """Decode backbone features into logits and a GPS prediction.

        In training mode the offset and cell geometry are selected with
        ``gt_label``; otherwise the argmax of the logits is used, and
        ``gt_label`` is not read.

        Args:
            x: backbone output; assumed to be 2-D, ``(batch, final_dim +
                2 * n_cells)`` -- TODO confirm against the caller.
            gt_label: integer cell index per sample, shape ``(batch,)``.

        Returns:
            dict with ``label`` (logits), ``gps`` (output of ``unorm``),
            ``size`` (reciprocal of the extent used to scale the offset),
            ``center`` (centroid of the selected cell) and ``reg``
            (selected offset).
        """
        classification_logits = x[..., : self.final_dim]
        classification = classification_logits.argmax(dim=-1)

        # Local, non-mutating device alignment: a no-op once the module
        # has been moved, but still supports a head left on another device.
        device = classification.device
        centers = self.cell_center.to(device)
        extent_up = self.cell_size_up.to(device)
        extent_down = self.cell_size_down.to(device)

        regression = x[..., self.final_dim :]
        if self.use_tanh:
            regression = self.scale_tanh * torch.tanh(regression)
        # One (lat, lon) offset per cell.
        regression = regression.view(regression.shape[0], -1, 2)

        # Ground-truth cell while training, predicted cell otherwise.
        cell = gt_label if self.training else classification

        index = cell.unsqueeze(-1).unsqueeze(-1).expand(regression.shape[0], 1, 2)
        regression = torch.gather(regression, 1, index)[:, 0, :]

        # Positive offsets are scaled by the distance to the upper bound,
        # the others by the distance to the lower bound.
        size = torch.where(regression > 0, extent_up[cell], extent_down[cell])
        center = centers[cell]
        gps = self.unorm(center + regression * size)

        return {
            "label": classification_logits,
            "gps": gps,
            "size": 1.0 / size,
            "center": center,
            "reg": regression,
        }
class SharedHybridHead(HybridHead):
    """Classification head followed by a SHARED regression head.

    Unlike the parent head, the entries after the logits are used directly
    as the offset, without any per-cell selection.
    """

    def forward(self, x, gt_label):
        """Decode backbone features into logits and a GPS prediction.

        Args:
            x: backbone output; the first ``final_dim`` entries of the last
                axis are logits, the remainder is the offset.
            gt_label: integer cell index per sample; used to pick the cell
                geometry in training mode and not read otherwise.

        Returns:
            dict with ``label`` (logits) and ``gps`` (output of ``unorm``).
        """
        logits = x[..., : self.final_dim]
        predicted_cell = logits.argmax(dim=-1)

        offset = x[..., self.final_dim :]
        if self.use_tanh:
            offset = self.scale_tanh * torch.tanh(offset)

        # Ground-truth cell while training, predicted cell otherwise.
        cell = gt_label if self.training else predicted_cell

        # The offset is expressed in units of half a cell extent.
        normalized_gps = self.cell_center[cell] + offset * self.cell_size[cell] / 2.0

        return {"label": logits, "gps": self.unorm(normalized_gps)}