File size: 1,568 Bytes
c2f815f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import numpy as np
import opensr_model
from typing import Union

def create_opensr_model(
    device: Union[str, torch.device] = "cpu",
    checkpoint: str = "./weights/opensr_10m_v4_v5.ckpt",
) -> "opensr_model.SRLatentDiffusion":
    """Build the OpenSR latent-diffusion model and load its pretrained weights.

    Args:
        device: Device forwarded to ``opensr_model.SRLatentDiffusion``.
        checkpoint: Path of the weights file handed to ``load_pretrained``.
            Defaults to the path this function previously hard-coded; being
            relative, it presumably resolves against the current working
            directory (depends on how ``load_pretrained`` opens it).

    Returns:
        The ``SRLatentDiffusion`` instance with ``eval()`` already called.
    """
    model = opensr_model.SRLatentDiffusion(device=device)
    model.load_pretrained(checkpoint)
    # Switch to inference mode before handing the model to callers.
    model.eval()
    return model


def _to_uint16(img: torch.Tensor) -> np.ndarray:
    """Scale the first three bands of ``img`` by 10000 and convert to uint16.

    Values are rounded to the nearest integer (plain ``astype`` truncates, so
    float32 round-off could lower a value by one) and clipped to the uint16
    range (out-of-range floats, e.g. negative model outputs, would otherwise
    wrap around or be undefined when cast). NaNs are mapped to 0.
    """
    arr = img.detach().cpu().numpy()[0:3] * 10000
    arr = np.nan_to_num(arr, nan=0.0)
    return np.clip(np.rint(arr), 0, np.iinfo(np.uint16).max).astype(np.uint16)


def run_opensr_model(
    model: "opensr_model.SRLatentDiffusion",
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cpu"
) -> dict:
    """Super-resolve ``lr`` with ``model`` and package the results.

    Args:
        model: Callable SR model. It is called on a ``(1, 4, H, W)`` float32
            tensor; its output is assumed to be ``(1, C, 4*H, 4*W)`` (the crop
            below uses a scale factor of 4) — TODO confirm against the model.
        lr: Band-first array. Bands ``[3, 2, 1, 7]`` along axis 0 are used, so
            it must have at least 8 bands. Values are divided by 10000
            (presumably a reflectance scale factor — verify against the data).
        hr: Band-first array; only ``hr[0:3]`` is returned, unchanged.
        device: Device the input tensor is moved to before inference.

    Returns:
        dict with keys ``"lr"`` and ``"sr"`` (uint16 arrays holding the first
        three selected bands, rescaled by 10000) and ``"hr"`` (``hr[0:3]``).
    """
    # A 121-pixel-high input is reflect-padded by 3 / 4 pixels to 128 before
    # inference, then the padding is removed from the input and the output
    # (scaled by the SR factor).
    padded_size = 121
    pad_before, pad_after = 3, 4
    scale = 4

    lr_img = torch.from_numpy(lr[[3, 2, 1, 7]] / 10000).float().to(device)
    hr_img = hr[0:3]

    needs_padding = lr_img.shape[1] == padded_size
    if needs_padding:
        lr_img = torch.nn.functional.pad(
            lr_img[None],
            pad=(pad_before, pad_after, pad_before, pad_after),
            mode="reflect",
        ).squeeze(0)

    with torch.no_grad():
        # Only drop the batch dimension, never a singleton channel/spatial dim.
        sr_img = model(lr_img[None]).squeeze(0)

    if needs_padding:
        lr_img = lr_img[:, pad_before:-pad_after, pad_before:-pad_after]
        sr_before, sr_after = pad_before * scale, pad_after * scale
        sr_img = sr_img[:, sr_before:-sr_after, sr_before:-sr_after]

    return {
        "lr": _to_uint16(lr_img),
        "sr": _to_uint16(sr_img),
        "hr": hr_img,
    }