guesstimatelocation / models /huggingface.py
yunusserhat's picture
Upload 40 files
94f372a verified
raw
history blame contribute delete
No virus
854 Bytes
import torch
from torch import nn
from hydra.utils import instantiate
from omegaconf import OmegaConf
from huggingface_hub import PyTorchModelHubMixin
class Geolocalizer(nn.Module, PyTorchModelHubMixin):
    """Hub-loadable wrapper around a Hydra-configured geolocation network.

    Everything (the preprocessing transform and the network itself) is built
    from one config mapping via ``hydra.utils.instantiate``. The instantiated
    network must expose ``backbone``, ``mid`` and ``head`` components. They
    are applied in that order, and the ``"gps"`` entry of the head's output is
    returned (presumably predicted coordinates; confirm against the head).
    """

    def __init__(self, config):
        """Build the transform and network described by ``config``.

        Args:
            config: Any value ``OmegaConf.create`` accepts (e.g. a plain
                dict). It must provide ``transform`` and ``model`` nodes that
                ``instantiate`` can build. NOTE(review): this is presumably
                the config that ``PyTorchModelHubMixin`` stores alongside the
                weights on the Hub. Verify against the mixin's docs.
        """
        super().__init__()
        cfg = OmegaConf.create(config)
        self.config = cfg
        # Built here but never applied inside this class. Presumably callers
        # run it on raw images before calling forward(); confirm at call sites.
        self.transform = instantiate(cfg.transform)
        network = instantiate(cfg.model)
        self.model = network
        # Direct aliases to the three pipeline stages. If these are nn.Module
        # instances, this assignment registers them a second time (under
        # head./mid./backbone. as well as model.*), so state dicts carry both
        # key sets. NOTE(review): published checkpoints likely depend on this
        # layout, so do not drop the aliases without checking.
        self.head, self.mid, self.backbone = (
            network.head,
            network.mid,
            network.backbone,
        )

    def _gps_from(self, backbone_input):
        """Run backbone -> mid -> head on ``backbone_input`` and return ``"gps"``."""
        features = self.backbone(backbone_input)
        features = self.mid(features)
        # The head takes a second positional argument, and both public entry
        # points pass None. Its meaning is not visible here; see the head's
        # definition.
        predictions = self.head(features, None)
        return predictions["gps"]

    def forward(self, img: torch.Tensor):
        """Predict from ``img``, passing the backbone a ``{"img": img}`` dict."""
        return self._gps_from({"img": img})

    def forward_tensor(self, img: torch.Tensor):
        """Predict from ``img``, passing the bare tensor straight to the backbone."""
        return self._gps_from(img)