Cesar Aybar commited on
Commit
61e0235
1 Parent(s): caa7010
Files changed (2) hide show
  1. benchmark.py +118 -6
  2. ldm_baseline/run.py +12 -6
benchmark.py CHANGED
@@ -1,29 +1,141 @@
1
- import rasterio
2
  import pathlib
 
 
3
 
4
  from typing import Callable
5
- from rasterio.transform import from_origin
6
 
7
 
8
  def create_geotiff(
 
9
  fn: Callable,
10
- dataset_snippet: str,
11
- output_path: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  ) -> pathlib.Path:
13
  """Create all the GeoTIFFs for a specific dataset snippet
14
 
15
  Args:
 
16
  fn (Callable): A function that return a dictionary with the following keys:
17
  - "lr": Low resolution image
18
  - "sr": Super resolution image
19
  - "hr": High resolution image
20
- dataset_snippet (str): The dataset snippet to use to run the fn function.
21
  output_path (str): The output path to save the GeoTIFFs.
 
 
22
 
23
  Returns:
24
  pathlib.Path: The output path where the GeoTIFFs are saved.
25
  """
26
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  def run(
 
1
+ import rasterio as rio
2
  import pathlib
3
+ import opensr_test
4
+ import matplotlib.pyplot as plt
5
 
6
  from typing import Callable
 
7
 
8
 
9
  def create_geotiff(
10
+ model: Callable,
11
  fn: Callable,
12
+ datasets: list,
13
+ output_path: str,
14
+ force: bool = False
15
+ ) -> None:
16
+ """Create all the GeoTIFFs for a specific dataset snippet
17
+
18
+ Args:
19
+ model (Callable): The model to use to run the fn function.
20
+ fn (Callable): A function that return a dictionary with the following keys:
21
+ - "lr": Low resolution image
22
+ - "sr": Super resolution image
23
+ - "hr": High resolution image
24
+ datasets (list): A list of dataset snippets to use to run the fn function.
25
+ output_path (str): The output path to save the GeoTIFFs.
26
+ force (bool, optional): If True, the dataset is redownloaded. Defaults
27
+ to False.
28
+ """
29
+ for snippet in datasets:
30
+ create_geotiff_batch(
31
+ model=model,
32
+ fn=fn,
33
+ snippet=snippet,
34
+ output_path=output_path,
35
+ force=force
36
+ )
37
+
38
+ return None
39
+
40
+ def create_geotiff_batch(
41
+ model: Callable,
42
+ fn: Callable,
43
+ snippet: str,
44
+ output_path: str,
45
+ force: bool = False
46
  ) -> pathlib.Path:
47
  """Create all the GeoTIFFs for a specific dataset snippet
48
 
49
  Args:
50
+ model (Callable): The model to use to run the fn function.
51
  fn (Callable): A function that return a dictionary with the following keys:
52
  - "lr": Low resolution image
53
  - "sr": Super resolution image
54
  - "hr": High resolution image
55
+ snippet (str): The dataset snippet to use to run the fn function.
56
  output_path (str): The output path to save the GeoTIFFs.
57
+ force (bool, optional): If True, the dataset is redownloaded. Defaults
58
+ to False.
59
 
60
  Returns:
61
  pathlib.Path: The output path where the GeoTIFFs are saved.
62
  """
63
+
64
+ # Create folders to save results
65
+ output_path = pathlib.Path(output_path) / "results" / "SR"
66
+ output_path.mkdir(parents=True, exist_ok=True)
67
+
68
+ output_path_dataset_geotiff = output_path / snippet / "geotiff"
69
+ output_path_dataset_geotiff.mkdir(parents=True, exist_ok=True)
70
+
71
+ output_path_dataset_png = output_path / snippet / "png"
72
+ output_path_dataset_png.mkdir(parents=True, exist_ok=True)
73
+
74
+ # Load the dataset
75
+ dataset = opensr_test.load(snippet, force=False)
76
+ lr_dataset, hr_dataset, metadata = dataset["L2A"], dataset["HRharm"], dataset["metadata"]
77
+ for index in range(len(lr_dataset)):
78
+ print(f"Processing {index}/{len(lr_dataset)}")
79
+
80
+ # Run the model
81
+ results = fn(
82
+ model=model,
83
+ lr=lr_dataset[index],
84
+ hr=hr_dataset[index]
85
+ )
86
+
87
+ # Get the image name
88
+ image_name = metadata.iloc[index]["hr_file"]
89
+
90
+ # Get the CRS and transform
91
+ crs = metadata.iloc[index]["crs"]
92
+ transform_str = metadata.iloc[index]["affine"]
93
+ transform_list = [float(x) for x in transform_str.split(",")]
94
+ transform_rio = rio.transform.from_origin(
95
+ transform_list[2],
96
+ transform_list[5],
97
+ transform_list[0],
98
+ transform_list[4] * -1
99
+ )
100
+
101
+ # Create rio dict
102
+ meta_img = {
103
+ "driver": "GTiff",
104
+ "count": 3,
105
+ "dtype": "uint16",
106
+ "height": results["hr"].shape[1],
107
+ "width": results["hr"].shape[2],
108
+ "crs": crs,
109
+ "transform": transform_rio,
110
+ "compress": "deflate",
111
+ "predictor": 2,
112
+ "tiled": True
113
+ }
114
+
115
+ # Save the GeoTIFF
116
+ with rio.open(output_path_dataset_geotiff / (image_name + ".tif"), "w", **meta_img) as dst:
117
+ dst.write(results["sr"])
118
+
119
+ # Save the PNG
120
+ fig, ax = plt.subplots(1, 3, figsize=(15, 5))
121
+ ax[0].imshow(results["lr"].transpose(1, 2, 0) / 3000)
122
+ ax[0].set_title("LR")
123
+ ax[0].axis("off")
124
+ ax[1].imshow(results["sr"].transpose(1, 2, 0) / 3000)
125
+ ax[1].set_title("SR")
126
+ ax[1].axis("off")
127
+ ax[2].imshow(results["hr"].transpose(1, 2, 0) / 3000)
128
+ ax[2].set_title("HR")
129
+ # remove whitespace around the image
130
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
131
+ plt.axis("off")
132
+ plt.savefig(output_path_dataset_png / (image_name + ".png"))
133
+ plt.close()
134
+ plt.clf()
135
+
136
+ return output_path_dataset_geotiff
137
+
138
+
139
 
140
 
141
  def run(
ldm_baseline/run.py CHANGED
@@ -1,5 +1,6 @@
1
  import matplotlib.pyplot as plt
2
  import opensr_test
 
3
 
4
  from ldm_baseline.utils import create_stable_diffusion_model, run_diffuser
5
 
@@ -10,11 +11,17 @@ device = "cuda:0"
10
  model = create_stable_diffusion_model(device=device)
11
 
12
  # Load the dataset
13
- dataset = opensr_test.load("spain_crops")
14
  lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
15
 
16
  # Run the model
17
- results = run_diffuser(model=model, lr=lr_dataset[5], hr=hr_dataset[5], device=device)
 
 
 
 
 
 
18
 
19
  # Display the results
20
  fig, ax = plt.subplots(1, 3, figsize=(10, 5))
@@ -29,7 +36,6 @@ ax[2].set_title("HR")
29
  plt.show()
30
 
31
  # Run the experiment
32
- #
33
- # benchmark.create_geotiff(run_diffuser, "all", "ldm_baseline/")
34
- # benchmark.run("all")
35
- # benchmark.plot("all")
 
1
  import matplotlib.pyplot as plt
2
  import opensr_test
3
+ import benchmark
4
 
5
  from ldm_baseline.utils import create_stable_diffusion_model, run_diffuser
6
 
 
11
  model = create_stable_diffusion_model(device=device)
12
 
13
  # Load the dataset
14
+ dataset = opensr_test.load("naip", force=False)
15
  lr_dataset, hr_dataset = dataset["L2A"], dataset["HRharm"]
16
 
17
  # Run the model
18
+ index = 5
19
+ results = run_diffuser(
20
+ model=model,
21
+ lr=lr_dataset[index],
22
+ hr=hr_dataset[index],
23
+ device=device
24
+ )
25
 
26
  # Display the results
27
  fig, ax = plt.subplots(1, 3, figsize=(10, 5))
 
36
  plt.show()
37
 
38
  # Run the experiment
39
+ # benchmark.create_geotiff(model, run_diffuser, ["naip"], "ldm_baseline/")
40
+ # benchmark.run(["naip"])
41
+ # benchmark.plot(["naip"])