|
import torch |
|
import yaml |
|
|
|
from swin2_mose.model import Swin2MoSE |
|
|
|
|
|
def to_shape(t1, t2):
    """Tile per-channel values so they line up with a batched tensor.

    ``t1`` receives a new leading axis, is repeated ``t2.shape[0]`` times
    along that axis, and is then viewed with shape ``t2.shape[:2] + (1, 1)``.
    For the final view to succeed, ``t1`` is expected to be 1-D with
    ``t2.shape[1]`` entries.

    Args:
        t1: Tensor of per-channel values.
        t2: Tensor whose first two dimensions define the target layout.

    Returns:
        A tensor of shape ``(t2.shape[0], t2.shape[1], 1, 1)`` holding the
        values of ``t1`` repeated for every entry of the leading dimension.
    """
    batch_size = t2.shape[0]
    target_shape = t2.shape[:2] + (1, 1)
    tiled = t1.unsqueeze(0).repeat(batch_size, 1)
    return tiled.view(target_shape)
|
|
|
|
|
def norm(tensor, mean, std):
    """Standardise ``tensor`` per channel: ``(tensor - mean) / std``.

    ``mean`` and ``std`` are converted to tensors, moved to the device of
    ``tensor`` and expanded with ``to_shape`` before the arithmetic.

    Args:
        tensor: Input tensor; its first two dimensions drive ``to_shape``.
        mean: Per-channel offsets (anything ``torch.tensor`` accepts).
        std: Per-channel scales (anything ``torch.tensor`` accepts).

    Returns:
        The standardised tensor.
    """
    device = tensor.device
    offset = to_shape(torch.tensor(mean).to(device), tensor)
    scale = to_shape(torch.tensor(std).to(device), tensor)
    centered = tensor - offset
    return centered / scale
|
|
|
|
|
def denorm(tensor, mean, std):
    """Map a standardised tensor back: ``tensor * std + mean`` per channel.

    ``mean`` and ``std`` are converted to tensors, moved to the device of
    ``tensor`` and expanded with ``to_shape`` before the arithmetic.

    Args:
        tensor: Input tensor; its first two dimensions drive ``to_shape``.
        mean: Per-channel offsets (anything ``torch.tensor`` accepts).
        std: Per-channel scales (anything ``torch.tensor`` accepts).

    Returns:
        The rescaled tensor.
    """
    device = tensor.device
    offset = to_shape(torch.tensor(mean).to(device), tensor)
    scale = to_shape(torch.tensor(std).to(device), tensor)
    scaled = tensor * scale
    return scaled + offset
|
|
|
|
|
def load_config(path):
    """Read the YAML file at ``path`` and return its parsed content.

    Args:
        path: Location of the YAML file, opened in text read mode.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file.
    """
    with open(path, 'r') as stream:
        return yaml.safe_load(stream)
|
|
|
|
|
def load_swin2_mose(model_weights, cfg):
    """Build a Swin2MoSE model from ``cfg`` and load weights from a checkpoint.

    Args:
        model_weights: Checkpoint location handed to ``torch.load``.
        cfg: Configuration mapping; ``cfg['super_res']['model']`` supplies the
            keyword arguments for the ``Swin2MoSE`` constructor.

    Returns:
        The model with ``checkpoint['model_state_dict']`` loaded and the
        configuration attached as ``model.cfg``.
    """
    # Deserialize onto the CPU so a checkpoint saved from GPU tensors can
    # still be read on a machine without CUDA.  load_state_dict copies the
    # values into the freshly built model's parameters, so the device of the
    # returned model does not depend on where the checkpoint was loaded.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_weights, map_location='cpu')

    sr_model = Swin2MoSE(**cfg['super_res']['model'])
    sr_model.load_state_dict(checkpoint['model_state_dict'])

    # Keep the configuration on the model so later code can read it.
    sr_model.cfg = cfg

    return sr_model
|
|
|
|
|
def run_swin2_mose(model, lr, hr, device='cuda'):
    """Super-resolve one low-resolution image and return uint16 results.

    Args:
        model: Callable model carrying its configuration as ``model.cfg``
            (as attached by ``load_swin2_mose``).  The model is not moved
            here, so it must already be able to consume tensors on
            ``device``.
        lr: NumPy array of the low-resolution image.  Positions 3, 2, 1 and
            7 of its first axis are selected, so that axis needs at least
            eight entries (presumably bands-first ``(bands, H, W)`` -- TODO
            confirm against the caller).
        hr: NumPy array of the high-resolution reference (presumably
            bands-first as well -- TODO confirm).
        device: Device the input tensors are moved to.

    Returns:
        A dict with keys ``'lr'``, ``'sr'`` and ``'hr'``; each value is a
        ``uint16`` NumPy array limited to its first three channels.  ``sr``
        is resized with nearest-neighbour interpolation when its spatial
        size differs from that of ``hr``.
    """
    stats = model.cfg['dataset']['stats']
    hr_stats = stats['tensor_05m_b2b3b4b8']
    lr_stats = stats['tensor_10m_b2b3b4b8']

    # Add a batch axis, convert to float and pick the four input channels.
    lr_orig = torch.from_numpy(lr)[None].float()[:, [3, 2, 1, 7]].to(device)
    hr_orig = torch.from_numpy(hr)[None].float().to(device)

    # Only the model input needs normalising; the reference image is
    # returned as-is, so no normalised copy of it is computed.
    lr_norm = norm(lr_orig, mean=lr_stats['mean'], std=lr_stats['std'])

    with torch.no_grad():
        sr = model(lr_norm)
        # Some model configurations return a pair; the first item is the
        # image and the second one is discarded.
        if not torch.is_tensor(sr):
            sr, _ = sr

    # The prediction lives in the normalised high-resolution value space.
    sr = denorm(sr, mean=hr_stats['mean'], std=hr_stats['std'])

    # Clamp to the uint16 range before casting: the denormalised prediction
    # is unbounded, and out-of-range floats would otherwise wrap around.
    sr = sr.round().clamp(0, 65535).cpu().numpy().astype('uint16')
    sr = sr.squeeze()[0:3]
    lr = lr_orig[0].cpu().numpy().astype('uint16').squeeze()[0:3]
    hr = hr_orig[0].cpu().numpy().astype('uint16').squeeze()[0:3]

    # Bring the prediction to the reference size when either spatial
    # dimension differs.
    if sr.shape[1:] != hr.shape[1:]:
        # Convert to float32 in NumPy first: torch.from_numpy does not
        # accept uint16 arrays on every PyTorch release.
        resized = torch.nn.functional.interpolate(
            torch.from_numpy(sr.astype('float32'))[None],
            size=hr.shape[1:],
            mode='nearest',
        )
        sr = resized.squeeze().numpy().astype('uint16')

    return {
        'lr': lr,
        'sr': sr,
        'hr': hr,
    }
|
|