File size: 1,035 Bytes
4d4dd90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from lightglue import LightGlue as LightGlue_
from omegaconf import OmegaConf

from ..base_model import BaseModel


class LightGlue(BaseModel):
    """Model wrapper around the standalone ``lightglue.LightGlue`` matcher.

    The wrapped network is built in ``_init`` from this model's config.
    ``_forward`` calls it with two per-image dicts passed under the keys
    ``"image0"`` and ``"image1"``, and returns its output unchanged.
    """

    # Upstream LightGlue defaults plus a "features" entry; ``_init`` pops that
    # entry and passes it as the first positional argument of ``LightGlue_``.
    default_conf = {"features": "superpoint", **LightGlue_.default_conf}
    required_data_keys = [
        "view0",
        "keypoints0",
        "descriptors0",
        "view1",
        "keypoints1",
        "descriptors1",
    ]

    def _init(self, conf):
        """Build the wrapped ``LightGlue_`` network from ``conf``.

        Args:
            conf: OmegaConf config for this model. It is converted to a plain
                container; its "features" entry becomes the positional
                argument, and every remaining entry is forwarded as a keyword.
        """
        dconf = OmegaConf.to_container(conf)
        self.net = LightGlue_(dconf.pop("features"), **dconf)
        # Provided by BaseModel; presumably flags the model as initialized
        # (e.g. so no checkpoint is required) -- confirm against BaseModel.
        self.set_initialized()

    @staticmethod
    def _gather_view(data, index):
        """Return ``data[f"view{index}"]`` merged with its per-view entries.

        Any of ``keypoints``, ``descriptors``, ``scales`` and ``oris`` found in
        ``data`` with the suffix ``index`` (e.g. ``"keypoints0"``) is copied
        into the result under its unsuffixed name. These top-level entries
        take precedence over keys of the same name in the view dict.
        """
        suffix = str(index)
        per_view_keys = ("keypoints", "descriptors", "scales", "oris")
        return {
            **data["view" + suffix],
            **{k: data[k + suffix] for k in per_view_keys if (k + suffix) in data},
        }

    def _forward(self, data):
        """Run the wrapped matcher on a pair of views.

        Args:
            data: dict holding at least ``required_data_keys``; optional
                ``scales{0,1}`` / ``oris{0,1}`` entries are forwarded too.

        Returns:
            Whatever ``self.net`` returns for the input
            ``{"image0": view0, "image1": view1}``.
        """
        view0 = self._gather_view(data, 0)
        view1 = self._gather_view(data, 1)
        return self.net({"image0": view0, "image1": view1})

    def loss(self, pred, data):
        """Not implemented: this wrapper defines no training loss.

        ``self`` was previously missing from the signature. Any bound call
        therefore raised ``TypeError`` instead of the intended
        ``NotImplementedError``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError