|
|
|
|
|
import numpy as np |
|
from scipy.optimize import least_squares |
|
import torch |
|
|
|
def align_scale_shift(pred, target, clip_max):
    """Least-squares fit of ``target ~= scale * pred + shift`` over valid entries.

    An entry counts as valid when its ``target`` value lies strictly between
    0 and ``clip_max``. The fit only runs when more than 10 entries are
    valid. Otherwise the identity transform ``(1, 0)`` is returned.

    Args:
        pred: Predicted values; must support boolean-mask indexing with the
            mask built from ``target`` (assumed same shape as ``target``).
        target: Reference values used both for the mask and as fit target.
        clip_max: Exclusive upper bound for a target value to be used.

    Returns:
        ``(scale, shift)``: the slope and intercept from ``np.polyfit``, or the
        ints ``(1, 0)`` when there are too few valid entries.
    """
    valid = (target > 0) & (target < clip_max)

    # Too few usable samples: leave the prediction untouched.
    if valid.sum() <= 10:
        return 1, 0

    # np.polyfit returns coefficients highest power first: [slope, intercept].
    slope, intercept = np.polyfit(pred[valid], target[valid], deg=1)
    return slope, intercept
|
|
|
def align_scale(pred: torch.Tensor, target: torch.Tensor):
    """Scale ``pred`` so its median matches the median of valid ``target`` values.

    Entries where ``target > 0`` are treated as valid. When more than 10
    entries are valid, the scale is the ratio of the target median to the
    prediction median over those entries. Otherwise the scale is 1.

    Args:
        pred: Predicted tensor; indexed with the mask built from ``target``
            (assumed same shape as ``target``).
        target: Reference tensor; non-positive entries are ignored.

    Returns:
        ``(pred * scale, scale)``. ``scale`` is a 0-d tensor when enough valid
        entries exist, otherwise the int ``1``.
    """
    mask = target > 0

    if torch.sum(mask) > 10:
        # 1e-8 avoids division by zero when the prediction median is exactly 0.
        scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8)
    else:
        scale = 1

    pred_scale = pred * scale
    return pred_scale, scale
|
|
|
def align_shift(pred: torch.Tensor, target: torch.Tensor):
    """Shift ``pred`` so its median matches the median of valid ``target`` values.

    Entries where ``target > 0`` are treated as valid. When more than 10
    entries are valid, the shift is the difference between the target median
    and the prediction median over those entries. Otherwise the shift is 0.

    Args:
        pred: Predicted tensor; indexed with the mask built from ``target``
            (assumed same shape as ``target``).
        target: Reference tensor; non-positive entries are ignored.

    Returns:
        ``(pred + shift, shift)``. ``shift`` is a 0-d tensor when enough valid
        entries exist, otherwise the int ``0``.
    """
    mask = target > 0

    if torch.sum(mask) > 10:
        # Plain difference of medians. No epsilon is needed because nothing is
        # divided; adding one would only bias the shift.
        shift = torch.median(target[mask]) - torch.median(pred[mask])
    else:
        shift = 0

    pred_shift = pred + shift
    return pred_shift, shift