import torch import yaml from model import Swin2MoSE def to_shape(t1, t2): t1 = t1[None].repeat(t2.shape[0], 1) t1 = t1.view((t2.shape[:2] + (1, 1))) return t1 def norm(tensor, mean, std): # get stats mean = torch.tensor(mean).to(tensor.device) std = torch.tensor(std).to(tensor.device) # denorm return (tensor - to_shape(mean, tensor)) / to_shape(std, tensor) def denorm(tensor, mean, std): # get stats mean = torch.tensor(mean).to(tensor.device) std = torch.tensor(std).to(tensor.device) # denorm return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor) def load_config(path): # load config with open(path, 'r') as f: cfg = yaml.safe_load(f) return cfg def load_swin2_mose(model_weights, cfg): # load checkpoint checkpoint = torch.load(model_weights) # build model sr_model = Swin2MoSE(**cfg['super_res']['model']) sr_model.load_state_dict( checkpoint['model_state_dict']) sr_model.cfg = cfg return sr_model def run_swin2_mose(model, lr, hr): cfg = model.cfg # norm fun hr_stats = cfg['dataset']['stats']['tensor_05m_b2b3b4b8'] lr_stats = cfg['dataset']['stats']['tensor_10m_b2b3b4b8'] # select 10m lr bands: B02, B03, B04, B08 and hr bands lr_orig = torch.tensor(lr)[None].float()[:, [3, 2, 1, 7]] hr_orig = torch.tensor(hr)[None].float() # normalize data lr = norm(lr_orig, mean=lr_stats['mean'], std=lr_stats['std']) hr = norm(hr_orig, mean=hr_stats['mean'], std=hr_stats['std']) # predict a image sr = model(lr) if not torch.is_tensor(sr): sr, _ = sr # denorm sr sr = denorm(sr, mean=hr_stats['mean'], std=hr_stats['std']) return { "lr": lr_orig[0], "sr": sr[0], "hr": hr_orig[0], }