csaybar commited on
Commit
26d106d
1 Parent(s): 99a3901

Upload 3 files

Browse files
Files changed (2) hide show
  1. satlas/run.py +18 -7
  2. satlas/utils.py +19 -2
satlas/run.py CHANGED
@@ -11,12 +11,23 @@ dataset = opensr_test.load("naip")
11
  lr_dataset, hr_dataset = dataset["L1C"], dataset["HRharm"]
12
 
13
  # Predict a image
14
- index = 20
15
- lr = torch.from_numpy(lr_dataset[index][[3, 2, 1]]/3558).float().to("cuda").clamp(0, 1)
16
- sr = run_satlas(model=model, lr=lr, cropsize=32, overlap=0)
 
 
 
 
17
 
18
- # Run the model
19
- fig, ax = plt.subplots(1, 2, figsize=(10, 5))
20
- ax[0].imshow(lr.cpu().numpy().transpose(1, 2, 0))
21
- ax[1].imshow(sr.cpu().numpy().transpose(1, 2, 0))
 
 
 
 
 
 
22
  plt.show()
 
 
11
  lr_dataset, hr_dataset = dataset["L1C"], dataset["HRharm"]
12
 
13
  # Predict a image
14
+ results = run_satlas(
15
+ model=model,
16
+ lr=lr_dataset[4],
17
+ hr=hr_dataset[4],
18
+ cropsize=32,
19
+ overlap=0
20
+ )
21
 
22
+ # Display the results
23
+ fig, ax = plt.subplots(1, 3, figsize=(10, 5))
24
+ ax[0].imshow(results["lr"].transpose(1, 2, 0)/10000)
25
+ ax[0].set_title("LR")
26
+ ax[0].axis("off")
27
+ ax[1].imshow(results["sr"].transpose(1, 2, 0)/10000)
28
+ ax[1].set_title("SR")
29
+ ax[1].axis("off")
30
+ ax[2].imshow(results["hr"].transpose(1, 2, 0) / 3000)
31
+ ax[2].set_title("HR")
32
  plt.show()
33
+
satlas/utils.py CHANGED
@@ -32,7 +32,17 @@ def load_satlas_sr(device: Union[str, torch.device] = "cuda") -> RRDBNet:
32
  return model
33
 
34
 
35
- def run_satlas(model, lr, cropsize: int = 32, overlap: int = 0):
 
 
 
 
 
 
 
 
 
 
36
  # Select the raster with the lowest resolution
37
  tshp = lr.shape
38
 
@@ -73,4 +83,11 @@ def run_satlas(model, lr, cropsize: int = 32, overlap: int = 0):
73
  sr_crop = model(crop[None])[0]
74
  sr[:, x*4:(x+cropsize)*4, y*4:(y+cropsize)*4] = sr_crop
75
 
76
- return sr
 
 
 
 
 
 
 
 
32
  return model
33
 
34
 
35
+ def run_satlas(
36
+ model: RRDBNet,
37
+ lr: torch.Tensor,
38
+ hr: torch.Tensor,
39
+ cropsize: int = 32,
40
+ overlap: int = 0,
41
+ device: Union[str, torch.device] = "cuda"
42
+ ) -> torch.Tensor:
43
+ # Load the LR image
44
+ lr = torch.from_numpy(lr[[3, 2, 1]]/3558).float().to(device).clamp(0, 1)
45
+
46
  # Select the raster with the lowest resolution
47
  tshp = lr.shape
48
 
 
83
  sr_crop = model(crop[None])[0]
84
  sr[:, x*4:(x+cropsize)*4, y*4:(y+cropsize)*4] = sr_crop
85
 
86
+ # Save the result
87
+ results = {
88
+ "lr": (lr.cpu().numpy() * 10000).astype(np.uint16),
89
+ "sr": (sr.cpu().numpy() * 10000).astype(np.uint16),
90
+ "hr": hr[0:3]
91
+ }
92
+
93
+ return results