Spaces:
Runtime error
Runtime error
# A reimplemented version in public environments by Xiao Fu and Mu Hu | |
import numpy as np | |
from scipy.optimize import least_squares | |
import torch | |
def align_scale_shift(pred, target, clip_max):
    """Least-squares fit of ``target ~ scale * pred + shift``.

    Only entries with ``0 < target < clip_max`` take part in the fit. When
    10 or fewer entries qualify, the identity transform ``(1, 0)`` is
    returned instead of a fit.

    Args:
        pred: Predicted values (array-like). It is indexed with a boolean
            mask built from ``target``, so it is assumed to share
            ``target``'s shape.
        target: Reference values (array-like).
        clip_max: Exclusive upper bound on ``target`` for an entry to count.

    Returns:
        ``(scale, shift)``: the slope and intercept from ``np.polyfit``, or
        ``(1, 0)`` when there are too few valid entries.
    """
    in_range = (target > 0) & (target < clip_max)
    if in_range.sum() <= 10:
        return 1, 0
    # A deg=1 polyfit returns coefficients highest power first: slope, intercept.
    slope, intercept = np.polyfit(pred[in_range], target[in_range], deg=1)
    return slope, intercept
def align_scale(pred: torch.Tensor, target: torch.Tensor):
    """Scale ``pred`` so that its median matches the median of ``target``.

    Only entries where ``target > 0`` take part in the estimate. When 10 or
    fewer such entries exist, the scale falls back to ``1``.

    Args:
        pred: Prediction tensor. It is indexed with the boolean mask built
            from ``target``, so it is assumed to share ``target``'s shape.
        target: Reference tensor; non-positive entries are treated as invalid.

    Returns:
        ``(pred * scale, scale)``. ``scale`` is a 0-d tensor, or the int ``1``
        in the fallback case.
    """
    valid = target > 0
    if torch.sum(valid) > 10:
        # torch.median returns the lower of the two middle values for an even
        # count; the 1e-8 avoids dividing by an exactly-zero median.
        scale = torch.median(target[valid]) / (torch.median(pred[valid]) + 1e-8)
    else:
        scale = 1
    pred_scale = pred * scale
    return pred_scale, scale
def align_shift(pred: torch.Tensor, target: torch.Tensor):
    """Shift ``pred`` so that its median matches the median of ``target``.

    Only entries where ``target > 0`` take part in the estimate. When 10 or
    fewer such entries exist, the shift falls back to ``0``.

    Args:
        pred: Prediction tensor. It is indexed with the boolean mask built
            from ``target``, so it is assumed to share ``target``'s shape.
        target: Reference tensor; non-positive entries are treated as invalid.

    Returns:
        ``(pred + shift, shift)``. ``shift`` is a 0-d tensor, or the int ``0``
        in the fallback case.
    """
    valid = target > 0
    if torch.sum(valid) > 10:
        # Plain median difference. There is no division here, so no epsilon
        # guard is needed; adding one would bias the shift by -eps.
        shift = torch.median(target[valid]) - torch.median(pred[valid])
    else:
        shift = 0
    pred_shift = pred + shift
    return pred_shift, shift