import torch
from torch import nn
from hydra.utils import instantiate
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin


class Geolocalizer(nn.Module, PyTorchModelHubMixin):
    """Config-driven wrapper around a backbone -> mid -> head pipeline.

    The transform and the model are both built from ``config`` with Hydra's
    ``instantiate``. The model's ``head``, ``mid`` and ``backbone`` attributes
    are additionally exposed directly on this module.
    """

    def __init__(self, config):
        """Build the transform and the model described by ``config``.

        Args:
            config: Anything accepted by ``OmegaConf.create``; it must provide
                ``transform`` and ``model`` entries that Hydra can
                instantiate. The instantiated model must expose ``head``,
                ``mid`` and ``backbone`` attributes.
        """
        super().__init__()
        self.config = OmegaConf.create(config)
        # The transform is only stored here; neither forward method applies
        # it (presumably callers preprocess images with it -- confirm).
        self.transform = instantiate(self.config.transform)
        self.model = instantiate(self.config.model)
        # Re-expose the three stages on this module, in the same order as
        # before (head, mid, backbone).
        for stage in ("head", "mid", "backbone"):
            setattr(self, stage, getattr(self.model, stage))

    def _run_pipeline(self, backbone_input):
        """Feed ``backbone_input`` through backbone, mid and head; return "gps"."""
        features = self.backbone(backbone_input)
        intermediate = self.mid(features)
        # The head is always called with ``None`` as its second argument.
        prediction = self.head(intermediate, None)
        return prediction["gps"]

    def forward(self, img: torch.Tensor):
        """Run the pipeline on ``img`` wrapped as ``{"img": img}``.

        Returns:
            The ``"gps"`` entry of the head's output.
        """
        return self._run_pipeline({"img": img})

    def forward_tensor(self, img: torch.Tensor):
        """Run the pipeline on ``img`` passed to the backbone unchanged.

        Returns:
            The ``"gps"`` entry of the head's output.
        """
        return self._run_pipeline(img)