File size: 1,526 Bytes
9c55c41
 
 
 
 
 
 
 
 
 
 
7c04404
9c55c41
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import numpy as np
import tensorflow as tf
import torch

def load_cesbio_sr(
    path: str = "sr4rs/weights/cesbio_model/sr4rs_sentinel2_bands4328_france2020_savedmodel",
) -> tf.function:
    """Load the CESBIO SR4RS SavedModel and return its inference function.

    Args:
        path (str, optional): Directory containing the SavedModel. Defaults to
            the CESBIO weights path used by this project. The default is a
            relative path, so it is resolved against the current working
            directory.

    Returns:
        tf.function: The first signature function exposed by the SavedModel,
        used to get the SR image.

    Raises:
        ValueError: If the SavedModel exposes no signatures.
    """
    model = tf.saved_model.load(path)

    # Take the first signature key. Iterating the signature mapping yields its
    # keys, which avoids building a throwaway list just to index element 0.
    signature = next(iter(model.signatures), None)
    if signature is None:
        raise ValueError(f"SavedModel at {path!r} exposes no signatures")

    return model.signatures[signature]

def run_sr4rs(
    model: tf.function,
    lr: np.ndarray,
    hr: np.ndarray,
) -> dict:
    """Run the SR4RS model on one image and collect RGB previews.

    Args:
        model (tf.function): Inference function, e.g. the one returned by
            ``load_cesbio_sr``. It is called with a ``(1, H, W, 4)`` float32
            tensor and must return a dict containing ``'output_32:0'``.
        lr (np.ndarray): Low resolution image, channels-first ``(bands, H, W)``.
            Bands 3, 2, 1 and 7 (in that order) along axis 0 are fed to the
            model, so at least 8 bands are required.
        hr (np.ndarray): High resolution image, channels-first. Only its first
            three bands are kept.

    Returns:
        dict: The results, with these entries:

        - ``"lr"``: bands 3, 2, 1 of ``lr``.
        - ``"sr"``: the first three bands of the model output, as uint16 with
          values clipped to ``[0, 65535]``, zero-padded by 32 px on every
          spatial border.
        - ``"hr"``: the first three bands of ``hr``.
    """
    margin = 32

    # (bands, H, W) -> (1, H, W, 4): pick the model's input bands, add a batch
    # axis and move channels last (NHWC). Cast explicitly so integer rasters
    # (e.g. uint16) are accepted.
    x = lr[[3, 2, 1, 7]][None].transpose(0, 2, 3, 1).astype("float32")
    pred = model(tf.convert_to_tensor(x, dtype=tf.float32))

    # (1, h, w, C) -> (C, h, w); index the batch axis rather than squeeze() so
    # no other singleton axis can be dropped by accident.
    sr = pred["output_32:0"].numpy()[0].transpose(2, 0, 1)

    # Clip before the uint16 cast: out-of-range floats (e.g. slightly negative
    # predictions) would otherwise wrap around or give undefined values.
    sr = np.clip(sr, 0, np.iinfo(np.uint16).max)

    # Zero-pad each spatial border. NOTE(review): presumably this restores the
    # border the model crops (cf. the 'output_32' name) — TODO confirm.
    sr = np.pad(
        sr,
        ((0, 0), (margin, margin), (margin, margin)),
        mode="constant",
        constant_values=0,
    ).astype("uint16")

    results = {
        "lr": lr[[3, 2, 1]],
        "sr": sr[0:3],
        "hr": hr[0:3],
    }

    return results