File size: 1,338 Bytes
6c08128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import numpy as np
from super_image import HanModel
from typing import Union

def create_superimage_model(
    device: Union[str, torch.device] = "cuda",
    *,
    model_name: str = "eugenesiow/han",
    scale: int = 4,
) -> HanModel:
    """Create the pretrained HAN super image model and move it to ``device``.

    Args:
        device (Union[str, torch.device], optional): The device the model is
            moved to. Defaults to "cuda".
        model_name (str, optional): Pretrained checkpoint identifier passed to
            ``HanModel.from_pretrained``. Defaults to "eugenesiow/han".
        scale (int, optional): Upscaling factor of the checkpoint to load.
            Defaults to 4.

    Returns:
        HanModel: The super image model, placed on ``device``.
    """
    model = HanModel.from_pretrained(model_name, scale=scale)
    # nn.Module.to moves the parameters in place and returns the module itself.
    return model.to(device)


def _to_uint16(tensor: torch.Tensor, norm_factor: float) -> np.ndarray:
    """Undo the input normalization and convert a tensor to a uint16 array.

    Values are rounded to the nearest integer (not truncated) and clipped to
    the uint16 range, so negative or overshooting values do not wrap around.
    """
    values = tensor.detach().cpu().numpy() * norm_factor
    return np.clip(np.rint(values), 0, np.iinfo(np.uint16).max).astype(np.uint16)


def run_superimage(
    model: HanModel,
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cuda",
    *,
    norm_factor: float = 2000.0,
) -> dict:
    """Run the super image model on bands [3, 2, 1] of the low resolution image.

    The selected bands are divided by ``norm_factor``, passed to ``model`` as a
    batch of one, and the output is multiplied back by ``norm_factor``.

    Args:
        model (HanModel): The super image model (assumed to already live on
            ``device``).
        lr (np.ndarray): The low resolution image. It is indexed along axis 0
            with ``[3, 2, 1]``, so that axis needs at least 4 entries
            (presumably a (bands, H, W) stack — TODO confirm with callers).
        hr (np.ndarray): The high resolution image. Only ``hr[0:3]`` is used,
            and it is returned unchanged apart from ``squeeze``.
            NOTE(review): lr uses bands [3, 2, 1] but hr uses [0:3];
            presumably hr is already RGB-ordered — confirm.
        device (Union[str, torch.device], optional): The device to run the
            model on. Defaults to "cuda".
        norm_factor (float, optional): Divisor applied to the input before
            inference and multiplier applied to the outputs afterwards.
            Defaults to 2000.0.

    Returns:
        dict: ``{"lr", "hr", "sr"}``. "lr" and "sr" are uint16 arrays (rounded
        and clipped to [0, 65535]) with singleton dimensions squeezed; "hr" is
        ``hr[0:3].squeeze()``.
    """
    # Cast in NumPy first: torch.from_numpy rejects uint16 arrays on many
    # torch versions, and the model expects float32 input anyway.
    rgb = lr[[3, 2, 1]].astype(np.float32)
    lr_tensor = torch.from_numpy(rgb).to(device) / norm_factor

    with torch.no_grad():
        # [None] adds the batch dimension the model expects.
        sr_tensor = model(lr_tensor[None])

    return {
        "lr": _to_uint16(lr_tensor, norm_factor).squeeze(),
        "hr": hr[0:3].squeeze(),
        "sr": _to_uint16(sr_tensor, norm_factor).squeeze(),
    }