# superIX / superimage / utils.py
# Uploaded by csaybar ("Upload 2 files"), commit 6c08128.
from typing import Sequence, Union

import numpy as np
import torch
from super_image import HanModel
def create_superimage_model(
    device: Union[str, torch.device] = "cuda",
    model_name: str = "eugenesiow/han",
    scale: int = 4,
) -> HanModel:
    """Create a pretrained HAN super image model.

    Args:
        device (Union[str, torch.device], optional): Device the model is moved
            to. Defaults to "cuda".
        model_name (str, optional): Pretrained checkpoint id passed to
            ``HanModel.from_pretrained``. Defaults to "eugenesiow/han".
        scale (int, optional): Upscaling factor of the checkpoint to load.
            Defaults to 4.

    Returns:
        HanModel: The pretrained model, in eval mode, on ``device``.
    """
    model = HanModel.from_pretrained(model_name, scale=scale)
    # The model is only used for inference here; make that explicit so any
    # train-time layer behaviour is disabled regardless of loader defaults.
    model.eval()
    return model.to(device)
def _to_uint16(array: np.ndarray, factor: float) -> np.ndarray:
    """Multiply a normalized array by ``factor`` and convert it to uint16 without wrap-around."""
    scaled = np.rint(np.asarray(array, dtype=np.float64) * factor)
    # Clip before casting: converting out-of-range floats (e.g. slightly
    # negative model outputs) directly to uint16 wraps around instead of
    # saturating at 0 / 65535.
    info = np.iinfo(np.uint16)
    return np.clip(scaled, info.min, info.max).astype(np.uint16)


def run_superimage(
    model: "HanModel",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cuda",
    *,
    lr_bands: Sequence[int] = (3, 2, 1),
    norm_factor: float = 2000.0,
):
    """Run the super image model on three selected low resolution bands.

    Args:
        model (HanModel): The super image model (any callable taking a
            (1, bands, H, W) tensor works).
        lr (np.ndarray): The low resolution image; bands are selected along
            axis 0.
        hr (np.ndarray): The high resolution image; its first three entries
            along axis 0 are returned unchanged.
        device (Union[str, torch.device], optional): The device to run the
            model on. Defaults to "cuda".
        lr_bands (Sequence[int], optional): Indices along axis 0 of ``lr`` fed
            to the model, in order. Defaults to (3, 2, 1).
            NOTE(review): presumably red/green/blue (e.g. Sentinel-2
            B4/B3/B2) — confirm.
        norm_factor (float, optional): ``lr`` is divided by this before
            inference and the outputs are multiplied back by it.
            Defaults to 2000.

    Returns:
        dict: ``{"lr": ..., "hr": ..., "sr": ...}``. "lr" (the selected
        bands, round-tripped through normalization) and "sr" (the model
        output) are uint16 arrays, rounded and clipped to [0, 65535], with
        singleton dimensions squeezed; "hr" is ``hr[0:3].squeeze()``.
    """
    # Cast on the NumPy side: torch.from_numpy rejects uint16 arrays on many
    # torch versions. list() keeps a tuple from being read as a multi-axis index.
    bands = np.asarray(lr[list(lr_bands)], dtype=np.float32)
    lr_tensor = torch.from_numpy(bands).to(device) / norm_factor
    with torch.no_grad():
        # Add a leading batch dimension of size 1.
        sr_tensor = model(lr_tensor[None])
    return {
        "lr": _to_uint16(lr_tensor.cpu().numpy(), norm_factor).squeeze(),
        "hr": hr[0:3].squeeze(),
        "sr": _to_uint16(sr_tensor.detach().cpu().numpy(), norm_factor).squeeze(),
    }