File size: 5,269 Bytes
61e0235
caa7010
61e0235
 
caa7010
307a330
caa7010
 
 
61e0235
caa7010
307a330
61e0235
307a330
 
61e0235
 
 
 
 
 
 
 
 
 
 
 
 
 
307a330
 
 
 
61e0235
 
 
 
 
 
307a330
 
61e0235
 
 
 
 
 
 
 
 
307a330
 
caa7010
 
 
 
61e0235
caa7010
 
 
 
61e0235
caa7010
61e0235
 
caa7010
 
 
 
61e0235
 
 
 
 
 
 
 
 
 
 
 
307a330
61e0235
 
 
 
 
 
 
 
307a330
 
61e0235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c8cf1a
61e0235
 
7c8cf1a
61e0235
 
7c8cf1a
61e0235
 
 
 
 
 
 
 
 
 
 
caa7010
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307a330
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import rasterio as rio
import pathlib
import opensr_test
import matplotlib.pyplot as plt

from typing import Callable, Union


def create_geotiff(
    model: Callable,
    fn: Callable,
    datasets: Union[str, list],
    output_path: str,
    force: bool = False,
    **kwargs
) -> None:
    """Create the SR GeoTIFFs (and PNG previews) for one or more dataset snippets.

    Each snippet is processed by ``create_geotiff_batch``.

    Args:
        model (Callable): The model passed through to ``fn``.
        fn (Callable): A function called as
            ``fn(model=..., lr=..., hr=..., **kwargs)``. It must return a
            dictionary with the following keys:
            - "lr": Low resolution image
            - "sr": Super resolution image
            - "hr": High resolution image
        datasets (Union[str, list]): What to process:
            - ``"all"``: every snippet in ``opensr_test.datasets``.
            - Any other string: treated as a single snippet name.
            - A list: a list of snippet names.
        output_path (str): Root folder. Results are written below
            ``<output_path>/results/SR/<snippet>/``.
        force (bool, optional): Forwarded to ``opensr_test.load``. Per the
            existing docs, True re-downloads the dataset. Defaults to False.
        **kwargs: Extra keyword arguments forwarded unchanged to ``fn``.

    Returns:
        None
    """
    if isinstance(datasets, str):
        # Without wrapping, iterating a single snippet name such as "naip"
        # would visit its characters one by one.
        datasets = opensr_test.datasets if datasets == "all" else [datasets]

    for snippet in datasets:
        create_geotiff_batch(
            model=model,
            fn=fn,
            snippet=snippet,
            output_path=output_path,
            force=force,
            **kwargs
        )

    return None

def _transform_from_affine_string(transform_str: str):
    """Build a rasterio transform from the comma-separated ``affine`` metadata string.

    Elements 2 and 5 are used as the upper-left x/y origin. Element 0 is the
    pixel width, and element 4 is negated to give the pixel height.
    NOTE(review): this assumes the string stores the affine coefficients in
    ``(a, b, c, d, e, f)`` order with no rotation terms. Confirm against the
    opensr_test metadata.
    """
    coeffs = [float(x) for x in transform_str.split(",")]
    return rio.transform.from_origin(
        coeffs[2],       # west: upper-left x
        coeffs[5],       # north: upper-left y
        coeffs[0],       # pixel width
        coeffs[4] * -1,  # from_origin expects a positive pixel height
    )


def _save_comparison_png(results: dict, png_file: pathlib.Path) -> None:
    """Save a side-by-side LR / SR / HR preview figure to ``png_file``."""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        for ax, key, title in zip(axes, ("lr", "sr", "hr"), ("LR", "SR", "HR")):
            # Move bands last for imshow. Dividing by 3000 and clipping to
            # [0, 1] is a fixed display stretch.
            ax.imshow((results[key].transpose(1, 2, 0) / 3000).clip(0, 1))
            ax.set_title(title)
            ax.axis("off")
        # Remove whitespace around the panels.
        fig.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig.savefig(png_file)
    finally:
        # Close exactly this figure so figures do not pile up across the loop,
        # even when plotting or saving fails.
        plt.close(fig)


def create_geotiff_batch(
    model: Callable,
    fn: Callable,
    snippet: str,
    output_path: str,
    force: bool = False,
    **kwargs
) -> pathlib.Path:
    """Create the SR GeoTIFFs and PNG previews for one dataset snippet.

    Outputs are written to ``<output_path>/results/SR/<snippet>/geotiff`` as
    ``<hr_file>.tif`` and to ``<output_path>/results/SR/<snippet>/png`` as
    ``<hr_file>.png``.

    Args:
        model (Callable): The model passed through to ``fn``.
        fn (Callable): A function called as
            ``fn(model=..., lr=..., hr=..., **kwargs)``. It must return a
            dictionary with the following keys:
            - "lr": Low resolution image
            - "sr": Super resolution image
            - "hr": High resolution image
        snippet (str): The dataset snippet to process.
        output_path (str): Root folder for the results.
        force (bool, optional): Forwarded to ``opensr_test.load``. Per the
            existing docs, True re-downloads the dataset. Defaults to False.
        **kwargs: Extra keyword arguments forwarded unchanged to ``fn``.

    Returns:
        pathlib.Path: The folder where the GeoTIFFs are saved.
    """
    # Create folders to save results.
    output_path = pathlib.Path(output_path) / "results" / "SR"
    geotiff_dir = output_path / snippet / "geotiff"
    png_dir = output_path / snippet / "png"
    for folder in (geotiff_dir, png_dir):
        folder.mkdir(parents=True, exist_ok=True)

    # Load the dataset.
    dataset = opensr_test.load(snippet, force=force)
    lr_dataset, hr_dataset, metadata = dataset["L2A"], dataset["HRharm"], dataset["metadata"]
    total = len(lr_dataset)
    for index in range(total):
        print(f"Processing {index + 1}/{total}")

        # Run the model.
        results = fn(
            model=model,
            lr=lr_dataset[index],
            hr=hr_dataset[index],
            **kwargs
        )

        row = metadata.iloc[index]
        image_name = row["hr_file"]
        sr = results["sr"]

        # Size the profile from the array actually written (SR), not from HR
        # with a hard-coded band count. Assumes SR is channels-first
        # (bands, rows, cols), the same layout the PNG preview transposes.
        meta_img = {
            "driver": "GTiff",
            "count": sr.shape[0],
            "dtype": "uint16",
            "height": sr.shape[1],
            "width": sr.shape[2],
            "crs": row["crs"],
            "transform": _transform_from_affine_string(row["affine"]),
            "compress": "deflate",
            "predictor": 2,
            "tiled": True,
        }

        # Save the GeoTIFF.
        with rio.open(geotiff_dir / f"{image_name}.tif", "w", **meta_img) as dst:
            dst.write(sr)

        # Save the PNG preview.
        _save_comparison_png(results, png_dir / f"{image_name}.png")

    return geotiff_dir




def run(
    model_path: str
) -> pathlib.Path:
    """Compute every benchmark metric for one model.

    Placeholder: nothing is implemented yet, so this ignores ``model_path``
    and returns ``None``.

    Args:
        model_path (str): Folder containing the model.

    Returns:
        pathlib.Path: Intended to be the pickle file holding the metrics.
        Currently ``None``.
    """
    # TODO: implement metric computation.
    return None


def plot(
    model_path: str
) -> pathlib.Path:
    """Produce the figures and summary tables for one model.

    Placeholder: nothing is implemented yet, so this ignores ``model_path``
    and returns ``None``.

    Args:
        model_path (str): Folder containing the model.

    Returns:
        pathlib.Path: Intended to be the folder holding the generated plots
        and tables. Currently ``None``.
    """
    # TODO: implement plot/table generation.
    return None