File size: 1,004 Bytes
2e23827
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# A version reimplemented for public environments by Xiao Fu and Mu Hu.

import numpy as np
from scipy.optimize import least_squares
import torch

def align_scale_shift(pred, target, clip_max):
    """Least-squares fit of ``target ~= scale * pred + shift`` over valid entries.

    An entry is valid when its target value lies strictly between 0 and
    ``clip_max``. When 10 or fewer entries are valid, the identity mapping
    ``(1, 0)`` is returned instead of fitting.

    Args:
        pred: Array of predictions, indexable by the same boolean mask as
            ``target`` (presumably the same shape -- confirm with callers).
        target: Array of reference values.
        clip_max: Exclusive upper bound on target values used in the fit.

    Returns:
        ``(scale, shift)``: the slope and intercept from ``np.polyfit``, or
        ``(1, 0)`` when there is too little valid data.
    """
    valid = (target > 0) & (target < clip_max)
    # Too few samples for a meaningful fit: fall back to the identity.
    if valid.sum() <= 10:
        return 1, 0
    # polyfit with deg=1 returns coefficients highest power first: [slope, intercept].
    coeffs = np.polyfit(pred[valid], target[valid], deg=1)
    return coeffs[0], coeffs[1]

def align_scale(pred: torch.Tensor, target: torch.Tensor):
    """Rescale ``pred`` so its median matches the median of ``target``.

    Only entries with ``target > 0`` are considered when computing the two
    medians. When 10 or fewer such entries exist, no alignment is attempted
    and the scale is 1.

    Args:
        pred: Prediction tensor, indexable by the same boolean mask as
            ``target`` (presumably the same shape -- confirm with callers).
        target: Reference tensor; non-positive entries are treated as invalid.

    Returns:
        ``(pred * scale, scale)``. ``scale`` is a 0-d tensor when computed, or
        the int ``1`` in the fallback case.
    """
    mask = target > 0
    if torch.sum(mask) > 10:
        # The 1e-8 guards against division by a zero median of pred.
        scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
    else:
        scale = 1
    pred_scale = pred * scale
    return pred_scale, scale

def align_shift(pred: torch.Tensor, target: torch.Tensor):
    """Shift ``pred`` so its median matches the median of ``target``.

    Only entries with ``target > 0`` are considered when computing the two
    medians. When 10 or fewer such entries exist, no alignment is attempted
    and the shift is 0.

    Args:
        pred: Prediction tensor, indexable by the same boolean mask as
            ``target`` (presumably the same shape -- confirm with callers).
        target: Reference tensor; non-positive entries are treated as invalid.

    Returns:
        ``(pred + shift, shift)``. ``shift`` is a 0-d tensor when computed, or
        the int ``0`` in the fallback case.
    """
    mask = target > 0
    if torch.sum(mask) > 10:
        # No epsilon here: there is no division to guard, and adding one would
        # bias every shift by -1e-8 (identical inputs would not give 0).
        shift = torch.median(target[mask]) - torch.median(pred[mask])
    else:
        shift = 0
    pred_shift = pred + shift
    return pred_shift, shift