|
|
import pathlib |
|
|
import safetensors.torch |
|
|
import matplotlib.pyplot as plt |
|
|
from sen2sr.models.opensr_baseline.swin import Swin2SR |
|
|
from sen2sr.models.tricks import HardConstraint |
|
|
from sen2sr.referencex2 import srmodel |
|
|
|
|
|
|
|
|
def example_data(path: pathlib.Path, *args, **kwargs):
    """Load the bundled low-resolution example input.

    Args:
        path: Directory containing "example_data.safetensor". A plain ``str``
            is also accepted and converted to ``pathlib.Path``.
        *args, **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        The tensor stored under the "lr" key of the safetensors file.
    """
    if isinstance(path, str):
        # ``str`` has no ``/`` join operator; convert so string paths work too.
        path = pathlib.Path(path)

    sample = safetensors.torch.load_file(path / "example_data.safetensor")
    return sample["lr"]
|
|
|
|
|
def _swin2sr_params() -> dict:
    """Return the Swin2SR hyper-parameters matching "model.safetensor".

    A new dict (with new lists) is built on every call, so nothing done to the
    arguments by one model instance can leak into another.
    """
    return {
        "img_size": (64, 64),
        "in_channels": 10,
        "out_channels": 6,
        "embed_dim": 192,
        "depths": [8] * 8,
        "num_heads": [8] * 8,
        "window_size": 4,
        "mlp_ratio": 4.0,
        "upscale": 1,
        "resi_connection": "1conv",
        "upsampler": "pixelshuffle",
    }


def _load_sr_model(path: pathlib.Path, device: str):
    """Build Swin2SR, load "model.safetensor" into it and move it to *device*."""
    sr_model_weights = safetensors.torch.load_file(path / "model.safetensor")
    sr_model = Swin2SR(**_swin2sr_params())
    sr_model.load_state_dict(sr_model_weights)
    return sr_model.to(device)


def _load_hard_constraint(path: pathlib.Path, device: str):
    """Build the HardConstraint from the mask in "hard_constraint.safetensor"."""
    hard_constraint_weights = safetensors.torch.load_file(
        path / "hard_constraint.safetensor"
    )
    return HardConstraint(
        low_pass_mask=hard_constraint_weights["weights"].to(device),
        bands=[0, 1, 2, 3, 4, 5],
        device=device,
    )


def _freeze(module):
    """Switch *module* to eval mode and disable gradients on its parameters."""
    module = module.eval()
    for param in module.parameters():
        param.requires_grad = False
    return module


def trainable_model(path, device: str = "cpu", *args, **kwargs):
    """Load the super-resolution model with its parameters left trainable.

    Args:
        path: Directory holding "model.safetensor" and
            "hard_constraint.safetensor" (``str`` or ``pathlib.Path``).
        device: Device the network and the hard-constraint mask are moved to.
        *args, **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        An ``srmodel`` wrapping the Swin2SR network and its HardConstraint.
        Neither is switched to eval mode nor frozen.
    """
    if isinstance(path, str):
        # ``str`` has no ``/`` join operator; convert so string paths work too.
        path = pathlib.Path(path)

    sr_model = _load_sr_model(path, device)
    hard_constraint = _load_hard_constraint(path, device)
    return srmodel(sr_model=sr_model, hard_constraint=hard_constraint, device=device)


def compiled_model(path, device: str = "cpu", *args, **kwargs):
    """Load the super-resolution model ready for inference.

    Same as ``trainable_model``, except that the Swin2SR network and the
    HardConstraint are put in eval mode and all their parameters have
    ``requires_grad`` set to False.

    Args:
        path: Directory holding "model.safetensor" and
            "hard_constraint.safetensor" (``str`` or ``pathlib.Path``).
        device: Device the network and the hard-constraint mask are moved to.
        *args, **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        An ``srmodel`` wrapping the frozen network and hard constraint.
    """
    if isinstance(path, str):
        # ``str`` has no ``/`` join operator; convert so string paths work too.
        path = pathlib.Path(path)

    sr_model = _freeze(_load_sr_model(path, device))
    hard_constraint = _freeze(_load_hard_constraint(path, device))
    return srmodel(sr_model=sr_model, hard_constraint=hard_constraint, device=device)
|
|
|
|
|
|
|
|
def _band_composite(image, bands):
    """Return ``image[0, bands]`` as a channels-last numpy array.

    The indexing plus the 3-axis transpose require a 4-D input; presumably it
    is (batch, channels, H, W), giving an (H, W, len(bands)) result suitable
    for ``imshow``.
    """
    return image[0, bands].detach().cpu().numpy().transpose(1, 2, 0)


def display_results(path: pathlib.Path, device: str = "cpu", *args, **kwargs):
    """Run the frozen model on the bundled example and plot LR vs SR.

    Args:
        path: Directory holding the model weights and
            "example_data.safetensor" (``str`` or ``pathlib.Path``).
        device: Device used for inference.
        *args, **kwargs: Accepted for interface compatibility; ignored.

    Returns:
        A matplotlib Figure with a 3x2 grid. Rows are the RGB, SWIR and RED
        composites; the left column is LR and the right column is SR.
    """
    import torch  # only needed for the no_grad inference context below

    if isinstance(path, str):
        # ``str`` has no ``/`` join operator; convert so string paths work too.
        path = pathlib.Path(path)

    model = compiled_model(path, device)
    lr = example_data(path)

    # Visualisation needs no autograd graph. no_grad also guarantees the
    # output does not require grad, which ``Tensor.numpy()`` would reject.
    with torch.no_grad():
        sr = model(lr.to(device))

    # (row title, channel indices taken from both LR and SR) per figure row.
    composites = (
        ("RGB", [2, 1, 0]),
        ("SWIR", [9, 8, 7]),
        ("RED", [6, 5, 4]),
    )

    # The SR crop is the LR crop scaled by ``scale`` (1, so both panels show
    # the same rows). Only the first (row) axis is cropped; all columns stay.
    scale = 1
    lr_crop = slice(16, 32 + 80)
    sr_crop = slice(lr_crop.start * scale, lr_crop.stop * scale)
    panels = (("LR", lr, lr_crop), ("SR", sr, sr_crop))

    fig, axes = plt.subplots(3, 2, figsize=(8, 12))
    for row, (name, bands) in enumerate(composites):
        for col, (prefix, image, crop) in enumerate(panels):
            composite = _band_composite(image, bands)[crop]
            ax = axes[row, col]
            # x2 brightens the image. Clipping to [0, 1] reproduces what
            # imshow does implicitly for float RGB data, minus the warning.
            ax.imshow((composite * 2).clip(0.0, 1.0))
            ax.set_title(f"{prefix} {name}")
            ax.axis("off")

    fig.tight_layout()
    return fig