Spaces:
Running
Running
File size: 2,470 Bytes
4d4dd90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
import torch.nn.functional as F
from omegaconf import OmegaConf
from .. import get_model
from ..base_model import BaseModel
# Shorthand for converting an OmegaConf DictConfig into a plain dict.
to_ctr = OmegaConf.to_container
class MixedExtractor(BaseModel):
    """Extractor that chains an optional detector and an optional descriptor.

    Each sub-model is instantiated only when its ``name`` is set in the
    configuration. Without a detector, the predictions start from
    ``data["cache"]``. If ``interpolate_descriptors_from`` names a field of
    the predictions, the descriptors are bilinearly sampled from that field
    at the keypoint locations.
    """

    default_conf = {
        "detector": {"name": None},
        "descriptor": {"name": None},
        "interpolate_descriptors_from": None,  # field name
    }

    required_data_keys = ["image"]
    required_cache_keys = []

    def _init(self, conf):
        """Instantiate the configured sub-models and record the required keys.

        Args:
            conf: merged configuration exposing ``detector.name`` and
                ``descriptor.name``.
        """
        # Work on copies: extending the lists in place with ``+=`` would
        # mutate the class-level lists and leak keys into other instances.
        data_keys = list(self.required_data_keys)
        cache_keys = list(self.required_cache_keys)
        needs_cache = False

        if conf.detector.name:
            self.detector = get_model(conf.detector.name)(to_ctr(conf.detector))
        else:
            needs_cache = True
            cache_keys.append("keypoints")

        if conf.descriptor.name:
            self.descriptor = get_model(conf.descriptor.name)(
                to_ctr(conf.descriptor)
            )
        else:
            needs_cache = True
            cache_keys.append("descriptors")

        # Register "cache" once, even when both sub-models are missing.
        if needs_cache and "cache" not in data_keys:
            data_keys.append("cache")

        self.required_data_keys = data_keys
        self.required_cache_keys = cache_keys

    def _forward(self, data):
        """Run the detector and descriptor and return the merged predictions.

        Args:
            data: mapping with at least ``"image"``; ``"cache"`` is read when
                no detector is configured.

        Returns:
            A dict of predictions. When ``interpolate_descriptors_from`` is
            set, ``"descriptors"`` holds the interpolated descriptors.
        """
        if self.conf.detector.name:
            pred = self.detector(data)
        else:
            # Shallow copy so that writing "descriptors" below does not
            # modify the caller's cache.
            pred = {**data["cache"]}

        # The descriptor must be gated on its own name, not the detector's:
        # ``self.descriptor`` only exists when ``descriptor.name`` is set.
        if self.conf.descriptor.name:
            pred = {**pred, **self.descriptor({**pred, **data})}

        if self.conf.interpolate_descriptors_from:
            h, w = data["image"].shape[-2:]
            kpts = pred["keypoints"]
            # Rescale (x, y) by (w, h) into the [-1, 1] range expected by
            # grid_sample, then insert a singleton axis for the sampling grid.
            # Assumes kpts is (B, N, 2) in pixel coordinates - TODO confirm.
            pts = (kpts / kpts.new_tensor([[w, h]]) * 2 - 1)[:, None]
            sampled = F.grid_sample(
                pred[self.conf.interpolate_descriptors_from],
                pts,
                align_corners=False,
                mode="bilinear",
            )
            # Presumably (B, C, 1, N) -> (B, N, C); verify against callers.
            pred["descriptors"] = sampled.squeeze(-2).transpose(-2, -1).contiguous()

        return pred

    def loss(self, pred, data):
        """Aggregate the losses and metrics of the configured sub-models.

        A sub-model is skipped when its ``name`` is unset, when its optional
        ``apply_loss`` flag is false, or when its ``loss`` method raises
        ``NotImplementedError``.

        Args:
            pred: predictions returned by the forward pass.
            data: input data; merged over ``pred`` before being handed to
                each sub-model loss.

        Returns:
            A ``(losses, metrics)`` tuple where ``losses["total"]`` is the
            sum of the sub-model totals (``0`` if none contributed).
        """
        losses = {}
        metrics = {}
        total = 0
        for k in ["detector", "descriptor"]:
            sub_conf = self.conf[k]
            apply = True
            if "apply_loss" in sub_conf.keys():
                apply = sub_conf.apply_loss
            if not (sub_conf.name and apply):
                continue
            try:
                losses_, metrics_ = getattr(self, k).loss(pred, {**pred, **data})
            except NotImplementedError:
                # This sub-model defines no loss; it contributes nothing.
                continue
            losses = {**losses, **losses_}
            metrics = {**metrics, **metrics_}
            total = losses_["total"] + total
        return {**losses, "total": total}, metrics
|