superIX / opensrmodel /utils.py
Cesar Aybar
up
c2f815f
raw
history blame
No virus
1.57 kB
import torch
import numpy as np
import opensr_model
from typing import Union
def create_opensr_model(
    device: Union[str, torch.device] = "cpu",
    weights_path: str = "./weights/opensr_10m_v4_v5.ckpt",
) -> "opensr_model.SRLatentDiffusion":
    """Build the OpenSR latent-diffusion model and load its pretrained weights.

    Args:
        device: Device forwarded to ``opensr_model.SRLatentDiffusion``.
        weights_path: Checkpoint passed to ``model.load_pretrained``. Defaults
            to the previously hard-coded path, which is resolved relative to
            the current working directory.

    Returns:
        The ``SRLatentDiffusion`` instance, switched to evaluation mode.
    """
    model = opensr_model.SRLatentDiffusion(device=device)
    model.load_pretrained(weights_path)
    # Inference-only helper: put the model in eval mode before returning it.
    model.eval()
    return model
def _to_uint16(img: torch.Tensor) -> np.ndarray:
    """Scale the first three bands of a reflectance tensor back to uint16.

    Multiplies by 10000 (undoing the division done on the input), rounds to
    the nearest integer and clips to the uint16 range. Rounding avoids the
    off-by-one loss that plain truncation can cause after a float32
    round-trip. Clipping prevents out-of-range values (e.g. slightly negative
    model outputs) from wrapping around when cast to an unsigned type.
    """
    scaled = img.detach().cpu().numpy()[0:3] * 10000
    return np.clip(np.rint(scaled), 0, np.iinfo(np.uint16).max).astype(np.uint16)


def run_opensr_model(
    model: "opensr_model.SRLatentDiffusion",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cpu",
) -> dict:
    """Super-resolve ``lr`` with ``model`` and return LR / SR / HR images.

    Args:
        model: Super-resolution model. It is only used as ``model(x)`` on a
            ``(1, 4, H, W)`` float tensor; its output is squeezed and indexed
            band-first.
        lr: Low-resolution image indexed band-first. Bands 3, 2, 1 and 7 are
            selected (in that order) and divided by 10000 before inference.
            NOTE(review): presumably a Sentinel-2 stack where these indices
            are R, G, B, NIR — confirm against the caller.
        hr: High-resolution reference; only ``hr[0:3]`` is returned.
        device: Device the LR tensor is moved to before inference.

    Returns:
        dict with keys ``"lr"`` and ``"sr"`` (first three bands, multiplied
        by 10000, rounded and clipped to uint16) and ``"hr"`` (``hr[0:3]``).
    """
    # Convert the input to a float32 tensor of scaled reflectances.
    lr_img = torch.from_numpy(lr[[3, 2, 1, 7]] / 10000).to(device).float()
    hr_img = hr[0:3]

    if lr_img.shape[1] == 121:
        # Reflect-pad 121 -> 128 (3 before, 4 after on both spatial axes);
        # presumably the model needs a side length of 128 — TODO confirm.
        pad_before, pad_after = 3, 4
        padded = torch.nn.functional.pad(
            lr_img[None],
            pad=(pad_before, pad_after, pad_before, pad_after),
            mode="reflect",
        )
        with torch.no_grad():
            sr_img = model(padded).squeeze()
        # Remove the padding from the SR output, scaled by the upscale factor
        # (derived from the shapes rather than assuming x4).
        scale = sr_img.shape[-1] // padded.shape[-1]
        start = pad_before * scale
        stop = pad_after * scale
        sr_img = sr_img[:, start:-stop, start:-stop]
        # lr_img itself was never padded, so it needs no cropping.
    else:
        with torch.no_grad():
            sr_img = model(lr_img[None]).squeeze()

    return {
        "lr": _to_uint16(lr_img),
        "sr": _to_uint16(sr_img),
        "hr": hr_img,
    }