File size: 1,769 Bytes
5e88f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sys

import torch

from dist import LOGGER


def lstq(A, F_u, F_v, lamda=0.01):
    """Solve two batched least-squares problems that share one design matrix.

    For every batch element, finds ``theta_x`` minimising ``||A theta_x - F_u||``
    and ``theta_y`` minimising ``||A theta_y - F_v||`` using a reduced QR
    factorisation ``A = Q R``, so that ``theta = R^-1 Q^T F``.

    Args:
        A: batched design matrix indexed as ``(B, L, N)`` by the ``bmm`` calls
            below. Assumed to have full column rank with ``L >= N``, so that
            ``R`` is square and invertible — TODO confirm at call sites.
        F_u: right-hand side ``(B, L, k_u)`` for the first problem.
        F_v: right-hand side ``(B, L, k_v)`` for the second problem.
        lamda: not used by this implementation (no regularisation is applied).
            It is accepted so existing callers that pass it keep working.

    Returns:
        ``(theta_x, theta_y)`` with shapes ``(B, N, k_u)`` and ``(B, N, k_v)``.

    If the factorisation or solve fails (e.g. ``R`` is singular because ``A``
    is rank-deficient), the exception is logged with ``LOGGER.exception`` and
    the process exits with status -1.
    """
    try:
        Q, R = torch.linalg.qr(A)
        # Solve both right-hand sides together: one Q^T product and one
        # inverse of R instead of inverting R once per component.
        rhs = torch.cat([F_u, F_v], dim=2)
        theta = torch.bmm(torch.linalg.inv(R), torch.bmm(Q.transpose(1, 2), rhs))
    except Exception:
        # Narrow from a bare ``except:`` so Ctrl-C / SystemExit are not
        # swallowed and rewritten into exit status -1.
        LOGGER.exception("Least Squares failed")
        sys.exit(-1)
    n_u = F_u.shape[2]
    return theta[:, :, :n_u], theta[:, :, n_u:]

def get_quad_flow(masks_softmaxed, flow, grid_x, grid_y):
    """Reconstruct a flow field as a sum of per-mask quadratic least-squares fits.

    For each mask channel ``k`` the two flow channels, weighted by that mask,
    are fitted (via ``lstq``) with the 6-term quadratic basis
    ``[x, y, x*x, y*y, x*y, 1]`` evaluated on the pixel grid, with each basis
    row also weighted by the mask. The fitted fields of all channels are summed.

    Args:
        masks_softmaxed: masks indexed as ``masks_softmaxed[:, k]`` for
            ``k in range(masks_softmaxed.size(1))``; presumably ``(B, K, H, W)``
            softmaxed over ``K`` — TODO confirm.
        flow: flow tensor whose channels 0 and 1 are fitted separately;
            presumably ``(B, 2, H, W)`` — TODO confirm.
        grid_x, grid_y: pixel-coordinate grids whose element count must equal
            the flattened spatial size of the masks (presumably ``(H, W)``).

    Returns:
        Tensor of shape ``(B, 2, *grid_x.shape)`` holding the summed
        reconstruction, or the int ``0`` if there are no mask channels.
    """
    # Loop-invariant: coordinates and the unweighted quadratic basis, 1 x L x 6.
    x = grid_x.unsqueeze(0).flatten(1)  # 1 x L
    y = grid_y.unsqueeze(0).flatten(1)  # 1 x L
    basis = torch.stack([x, y, x * x, y * y, x * y, torch.ones_like(y)], 2)

    rec_flow = 0
    for k in range(masks_softmaxed.size(1)):
        mask = masks_softmaxed[:, k].unsqueeze(1)
        weighted_flow = flow * mask
        bs = weighted_flow.shape[0]
        M = mask.flatten(1)  # B x L

        F_u = weighted_flow[:, 0].flatten(1).unsqueeze(2)  # B x L x 1
        F_v = weighted_flow[:, 1].flatten(1).unsqueeze(2)  # B x L x 1
        # Mask-weighted design matrix, B x L x 6 (same values as stacking
        # x*M, y*M, x*x*M, y*y*M, x*y*M, 1*M per segment).
        A = basis * M.unsqueeze(2)

        theta_x, theta_y = lstq(A, F_u, F_v, lamda=.01)
        rec_flow_m = torch.stack([torch.bmm(A, theta_x).view(bs, *grid_x.shape),
                                  torch.bmm(A, theta_y).view(bs, *grid_y.shape)], 1)

        rec_flow += rec_flow_m
    return rec_flow


# Module-level settings, overridable at runtime through set_subsample_skip().
# NOTE(review): their meaning is defined by code outside this view — confirm.
SUBSAMPLE: int = 8  # presumably a subsampling factor
SKIP: float = 0.4
SIZE: float = 0.3
NITER: int = 50  # presumably an iteration count
METHOD: str = "inv_score"

def set_subsample_skip(sub=None, skip=None, size=None, niter=None, method=None):
    """Override module-level settings; any argument left as ``None`` is untouched.

    Despite the name, all five globals can be updated:
    ``sub`` -> ``SUBSAMPLE``, ``skip`` -> ``SKIP``, ``size`` -> ``SIZE``,
    ``niter`` -> ``NITER``, ``method`` -> ``METHOD``.
    """
    global SUBSAMPLE, SKIP, SIZE, NITER, METHOD

    if sub is not None:
        SUBSAMPLE = sub
    if skip is not None:
        SKIP = skip
    if size is not None:
        SIZE = size
    if niter is not None:
        NITER = niter
    if method is not None:
        METHOD = method