|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
from scipy.stats import zscore |
|
|
|
def edge_str(i, j):
    """Build the string key ``'<i>_<j>'`` that names the edge between images i and j."""
    edge_key = f'{i}_{j}'
    return edge_key
|
|
|
|
|
def i_j_ij(ij):
    """Pair an index pair with its string key.

    Returns a 2-tuple ``(edge_str(*ij), ij)``: the edge's string key first, then
    the original pair object unchanged.
    """
    key = edge_str(*ij)
    return key, ij
|
|
|
|
|
def edge_conf(conf_i, conf_j, edge):
    """Score an edge as the product of its two mean confidences.

    ``conf_i[edge]`` and ``conf_j[edge]`` are looked up and reduced with
    ``.mean()``; the product of the two means is returned as a Python float.
    Presumably the containers hold per-edge confidence maps (tensors/arrays) —
    confirm against callers.
    """
    mean_i = conf_i[edge].mean()
    mean_j = conf_j[edge].mean()
    return float(mean_i * mean_j)
|
|
|
|
|
def compute_edge_scores(edges, conf_i, conf_j):
    """Map each ``(i, j)`` image pair to its ``edge_conf`` score.

    ``edges`` is iterated as ``(key, (i, j))`` items: ``key`` is used to index
    ``conf_i`` / ``conf_j`` and the unpacked pair ``(i, j)`` becomes the
    dictionary key of the result.
    """
    scores = {}
    for key, (i, j) in edges:
        scores[i, j] = edge_conf(conf_i, conf_j, key)
    return scores
|
|
|
def NoGradParamDict(x):
    """Wrap a plain dict in an ``nn.ParameterDict`` whose parameters are frozen.

    Asserts that ``x`` is a ``dict``; every parameter of the resulting module has
    ``requires_grad`` switched off before it is returned.
    """
    assert isinstance(x, dict)
    frozen = nn.ParameterDict(x)
    frozen.requires_grad_(False)
    return frozen
|
|
|
|
|
def get_imshapes(edges, pred_i, pred_j):
    """Collect the leading two dimensions of every image referenced by ``edges``.

    Args:
        edges: sequence of ``(i, j)`` image-index pairs. It is iterated twice
            (once to size the result, once to fill it), so it must be a
            re-iterable container, not a one-shot iterator.
        pred_i: per-edge predictions indexed by edge position ``e``;
            ``pred_i[e].shape[0:2]`` is recorded as the shape of image ``i``.
        pred_j: same as ``pred_i`` but for image ``j`` of each edge.

    Returns:
        A list of length ``max image index + 1`` whose entry ``k`` is the
        shape tuple of image ``k``, or ``None`` if no edge mentions image ``k``.
        An empty ``edges`` yields an empty list.

    Raises:
        AssertionError: if two edges disagree on the shape of the same image.
    """
    # default=-1 turns an empty edge list into zero images instead of letting
    # max() raise ValueError on an empty sequence.
    n_imgs = max((max(e) for e in edges), default=-1) + 1
    imshapes = [None] * n_imgs
    for e, (i, j) in enumerate(edges):
        shape_i = tuple(pred_i[e].shape[0:2])
        shape_j = tuple(pred_j[e].shape[0:2])
        # Compare against None explicitly: a recorded-but-falsy shape such as
        # () must still be checked for consistency, not treated as "unset".
        if imshapes[i] is not None:
            assert imshapes[i] == shape_i, f'incorrect shape for image {i}'
        if imshapes[j] is not None:
            assert imshapes[j] == shape_j, f'incorrect shape for image {j}'
        imshapes[i] = shape_i
        imshapes[j] = shape_j
    return imshapes
|
|
|
|
|
def get_conf_trf(mode):
    """Return the confidence transform selected by ``mode``.

    Supported modes:
        ``'log'``  -> ``x.log()``
        ``'sqrt'`` -> ``x.sqrt()``
        ``'m1'``   -> ``x - 1``
        ``'id'`` or ``'none'`` -> identity

    Raises:
        ValueError: if ``mode`` matches none of the names above.
    """
    # Ordered (name, transform) pairs compared with ==, so any ``mode`` value
    # (hashable or not) is handled exactly like a chain of equality tests.
    table = (
        ('log', lambda x: x.log()),
        ('sqrt', lambda x: x.sqrt()),
        ('m1', lambda x: x - 1),
        ('id', lambda x: x),
        ('none', lambda x: x),
    )
    for name, transform in table:
        if mode == name:
            return transform
    raise ValueError(f'bad mode for {mode=}')
|
|
|
|
|
def l2_dist(a, b, weight):
    """Weighted squared Euclidean distance between ``a`` and ``b``.

    Squares the element-wise difference, sums over the last axis, then
    multiplies by ``weight`` (presumably broadcastable to the reduced shape —
    confirm against callers).
    """
    diff = a - b
    squared = diff.square().sum(dim=-1)
    return squared * weight


def l1_dist(a, b, weight):
    """Weighted (unsquared) distance between ``a`` and ``b`` along the last axis.

    NOTE: despite the name this is the Euclidean norm of the difference
    (``Tensor.norm`` with its default order), not a sum of absolute values.
    The result is multiplied by ``weight``.
    """
    diff = a - b
    magnitude = diff.norm(dim=-1)
    return magnitude * weight


# Lookup table from distance name to implementation.
ALL_DISTS = {'l1': l1_dist, 'l2': l2_dist}
|
|
|
|
|
def signed_log1p(x):
    """Compress magnitudes with ``log1p`` while keeping the sign of ``x``.

    Computes ``sign(x) * log1p(|x|)`` element-wise.
    """
    compressed = torch.log1p(x.abs())
    return torch.sign(x) * compressed
|
|
|
|
|
def signed_expm1(x):
    """Expand magnitudes with ``expm1`` while keeping the sign of ``x``.

    Computes ``sign(x) * expm1(|x|)`` element-wise.
    """
    expanded = torch.expm1(x.abs())
    return torch.sign(x) * expanded
|
|
|
|
|
def cosine_schedule(t, lr_start, lr_end):
    """Cosine-anneal from ``lr_start`` (at ``t=0``) to ``lr_end`` (at ``t=1``).

    ``t`` is the training progress and must lie in ``[0, 1]`` (asserted).
    """
    assert 0 <= t <= 1
    span = lr_start - lr_end
    cos_term = 1 + np.cos(t * np.pi)
    return lr_end + span * cos_term / 2
|
|
|
|
|
def linear_schedule(t, lr_start, lr_end):
    """Linearly interpolate from ``lr_start`` (at ``t=0``) to ``lr_end`` (at ``t=1``).

    ``t`` must lie in ``[0, 1]`` (asserted).
    """
    assert 0 <= t <= 1
    delta = lr_end - lr_start
    return lr_start + delta * t
|
|
|
def cycled_linear_schedule(t, lr_start, lr_end, num_cycles=2):
    """Repeat ``linear_schedule`` ``num_cycles`` times over ``t`` in ``[0, 1]``.

    ``t`` is scaled by ``num_cycles`` and only the fractional part is kept, so
    each cycle restarts at ``lr_start``. At ``t == 1`` the phase is pinned to 1
    so the final value is ``lr_end`` instead of wrapping back to the start.
    """
    assert 0 <= t <= 1
    scaled = t * num_cycles
    phase = 1 if t == 1 else scaled - int(scaled)
    return linear_schedule(phase, lr_start, lr_end)