Realcat
add: GIM (https://github.com/xuelunshen/gim)
4d4dd90
raw
history blame
No virus
4.82 kB
import string
import h5py
import torch
from ..datasets.base_dataset import collate
from ..settings import DATA_PATH
from ..utils.tensor import batch_to_device
from .base_model import BaseModel
from .utils.misc import pad_to_length
def pad_local_features(pred: dict, seq_l: int):
    """Pad the local-feature entries of ``pred`` to length ``seq_l``.

    ``pred["keypoints"]`` is mandatory and is always padded; every other
    known entry is padded only if it is present. Each entry is forwarded to
    ``pad_to_length`` together with ``seq_l``, a third positional argument
    (-2 or -1, presumably the dimension to pad -- confirm against
    ``pad_to_length``) and a fill ``mode``.

    Args:
        pred: prediction dictionary; it is modified in place.
        seq_l: target length handed to ``pad_to_length``.

    Returns:
        The same ``pred`` dictionary, with the padded entries replaced.
    """
    pred["keypoints"] = pad_to_length(
        pred["keypoints"],
        seq_l,
        -2,
        mode="random_c",
    )

    # (key, third positional argument, fill mode), in the order the entries
    # are processed. The order is kept fixed because the "random" mode may
    # consume random state.
    optional_specs = (
        ("keypoint_scores", -1, "zeros"),
        ("descriptors", -2, "random"),
        ("scales", -1, "zeros"),
        ("oris", -1, "zeros"),
        ("depth_keypoints", -1, "zeros"),
        ("valid_depth_keypoints", -1, "zeros"),
    )
    for key, dim, fill_mode in optional_specs:
        if key in pred:
            pred[key] = pad_to_length(pred[key], seq_l, dim, mode=fill_mode)
    return pred
def pad_line_features(pred, seq_l: int = None):
    """Placeholder for padding line features; not implemented.

    Args:
        pred: prediction dictionary (unused).
        seq_l: target length (unused).

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
def recursive_load(grp, pkeys):
    """Load the entries ``pkeys`` of an HDF5 group into a dict of tensors.

    Each entry that is an ``h5py.Dataset`` is read into memory and converted
    with ``torch.from_numpy``. Any other entry is treated as a sub-group and
    loaded recursively with all of its own keys, producing a nested dict.

    Args:
        grp: an open HDF5 group (or file) supporting ``grp[key]``.
        pkeys: iterable of keys of ``grp`` to load.

    Returns:
        A dict mapping each key in ``pkeys`` to a tensor or to a nested dict.
    """
    loaded = {}
    for k in pkeys:
        item = grp[k]  # look the entry up once
        if isinstance(item, h5py.Dataset):
            loaded[k] = torch.from_numpy(item.__array__())
        else:
            # Recurse with the sub-group's own keys. The previous code passed
            # the parent's keys, which do not exist inside the sub-group.
            loaded[k] = recursive_load(item, list(item.keys()))
    return loaded
class CacheLoader(BaseModel):
    """Model that returns predictions read from HDF5 files.

    For every element of ``data["name"]`` the configured ``path`` is
    formatted, the file is opened read-only, and the group with that name is
    loaded via ``recursive_load``. The loaded entries are optionally cast,
    moved to a device, rescaled, padded, and collated into a batch.
    """

    default_conf = {
        "path": "???",  # can be a format string like exports/{scene}/
        "data_keys": None,  # load all keys
        "device": None,  # load to same device as data
        "trainable": False,
        "add_data_path": True,
        "collate": True,
        "scale": ["keypoints", "lines", "orig_lines"],
        "padding_fn": None,
        "padding_length": None,  # required for batching!
        "numeric_type": "float32",  # [None, "float16", "float32", "float64"]
    }

    required_data_keys = ["name"]  # we need an identifier

    def _init(self, conf):
        """Resolve the padding function and the floating-point dtype from ``conf``."""
        self.hfiles = {}
        self.padding_fn = conf.padding_fn
        if self.padding_fn is not None:
            # SECURITY NOTE: the configured string is evaluated as Python code
            # (typically the name of a function in this module, such as
            # "pad_local_features"). Only use configurations from trusted
            # sources.
            self.padding_fn = eval(self.padding_fn)
        self.numeric_dtype = {
            None: None,
            "float16": torch.float16,
            "float32": torch.float32,
            "float64": torch.float64,
        }[conf.numeric_type]

    def _forward(self, data):
        """Load the cached predictions for every name in ``data["name"]``.

        Args:
            data: batch dictionary. Must contain ``"name"``, plus every field
                referenced by the ``path`` format string. ``"scales"`` (or
                ``data["view{idx}"]["scales"]``) is read when a loaded key
                starts with one of the ``scale`` patterns.

        Returns:
            The collated predictions if ``conf.collate`` is set; otherwise
            the single prediction dict (exactly one name is then required).
        """
        device = self.conf.device
        if not device:
            # Default to the device shared by the tensors in `data`.
            devices = {
                v.device for v in data.values() if isinstance(v, torch.Tensor)
            }
            if not devices:
                device = "cpu"
            else:
                assert len(devices) == 1
                device = devices.pop()

        # Names of the replacement fields used in the path format string.
        var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]

        preds = []
        for i, name in enumerate(data["name"]):
            fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
            if self.conf.add_data_path:
                fpath = DATA_PATH / fpath

            # The context manager closes the file even if loading raises.
            # recursive_load reads everything into memory, so the tensors
            # remain valid after the file is closed.
            with h5py.File(str(fpath), "r") as hfile:
                grp = hfile[name]
                pkeys = (
                    self.conf.data_keys
                    if self.conf.data_keys is not None
                    else grp.keys()
                )
                pred = recursive_load(grp, pkeys)

            if self.numeric_dtype is not None:
                # Cast floating-point tensors only; leave everything else as is.
                pred = {
                    k: v
                    if not isinstance(v, torch.Tensor)
                    or not torch.is_floating_point(v)
                    else v.to(dtype=self.numeric_dtype)
                    for k, v in pred.items()
                }
            pred = batch_to_device(pred, device)

            # Rescale entries whose key starts with one of the configured
            # patterns. Any suffix after the pattern selects the view whose
            # scales are used (e.g. "keypoints0" -> data["view0"]["scales"]).
            for k in list(pred):
                for pattern in self.conf.scale:
                    if k.startswith(pattern):
                        view_idx = k.replace(pattern, "")
                        scales = (
                            data["scales"]
                            if len(view_idx) == 0
                            else data[f"view{view_idx}"]["scales"]
                        )
                        pred[k] = pred[k] * scales[i]

            # use this function to fix number of keypoints etc.
            if self.padding_fn is not None:
                pred = self.padding_fn(pred, self.conf.padding_length)
            preds.append(pred)

        if self.conf.collate:
            return batch_to_device(collate(preds), device)
        else:
            assert len(preds) == 1
            return batch_to_device(preds[0], device)

    def loss(self, pred, data):
        """Not implemented for this model; always raises ``NotImplementedError``."""
        raise NotImplementedError