carlosgomes98 committed on
Commit c2435ef
1 Parent(s): ffd2463

fix inference

Files changed (3)
  1. Prithvi_100M_config.yaml +18 -18
  2. Prithvi_run_inference.py +231 -113
  3. README.md +3 -1
Prithvi_100M_config.yaml CHANGED
@@ -12,25 +12,25 @@ model_args:
12
  tubelet_size: 1
13
  train_params:
14
  bands:
15
- - B02
16
- - B03
17
- - B04
18
- - B05
19
- - B06
20
- - B07
21
  data_mean:
22
- - 775.2290211032589
23
- - 1080.992780391705
24
- - 1228.5855250417867
25
- - 2497.2022620507532
26
- - 2204.2139147975554
27
- - 1610.8324823273745
28
  data_std:
29
- - 1281.526139861424
30
- - 1270.0297974547493
31
- - 1399.4802505642526
32
- - 1368.3446143747644
33
- - 1291.6764008585435
34
- - 1154.505683480695
35
  mask_ratio: 0.75
36
  random_cropping: true
 
12
  tubelet_size: 1
13
  train_params:
14
  bands:
15
+ - B02
16
+ - B03
17
+ - B04
18
+ - B05
19
+ - B06
20
+ - B07
21
  data_mean:
22
+ - 775.2290211032589
23
+ - 1080.992780391705
24
+ - 1228.5855250417867
25
+ - 2497.2022620507532
26
+ - 2204.2139147975554
27
+ - 1610.8324823273745
28
  data_std:
29
+ - 1281.526139861424
30
+ - 1270.0297974547493
31
+ - 1399.4802505642526
32
+ - 1368.3446143747644
33
+ - 1291.6764008585435
34
+ - 1154.505683480695
35
  mask_ratio: 0.75
36
  random_cropping: true
Prithvi_run_inference.py CHANGED
@@ -1,7 +1,7 @@
1
  import argparse
2
  import functools
3
  import os
4
- from typing import List
5
 
6
  import numpy as np
7
  import rasterio
@@ -17,7 +17,7 @@ PERCENTILES = (0.1, 99.9)
17
 
18
 
19
  def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
20
- """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
21
  original range using *data_mean* and *data_std* and then lowest and highest percentiles are
22
  removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
23
 
@@ -65,7 +65,7 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
65
 
66
 
67
  def read_geotiff(file_path: str):
68
- """ Read all bands from *file_path* and return image + meta info.
69
 
70
  Args:
71
  file_path: path to image file.
@@ -83,7 +83,7 @@ def read_geotiff(file_path: str):
83
 
84
 
85
  def save_geotiff(image, output_path: str, meta: dict):
86
- """ Save multi-band image in Geotiff file.
87
 
88
  Args:
89
  image: np.ndarray with shape (bands, height, width)
@@ -99,15 +99,19 @@ def save_geotiff(image, output_path: str, meta: dict):
99
 
100
 
101
  def _convert_np_uint8(float_image: torch.Tensor):
102
-
103
  image = float_image.numpy() * 255.0
104
  image = image.astype(dtype=np.uint8)
105
 
106
  return image
107
 
108
 
109
- def load_example(file_paths: List[str], mean: List[float], std: List[float]):
110
- """ Build an input example by loading images in *file_paths*.
 
 
 
 
 
111
 
112
  Args:
113
  file_paths: list of file paths.
@@ -126,21 +130,28 @@ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
126
  img, meta = read_geotiff(file)
127
 
128
  # Rescaling (don't normalize on nodata)
129
- img = np.moveaxis(img, 0, -1) # channels last for rescaling
 
 
130
  img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
131
 
132
  imgs.append(img)
133
  metas.append(meta)
134
 
135
- imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
136
- imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
137
  imgs = np.expand_dims(imgs, axis=0) # add batch dim
138
 
139
  return imgs, metas
140
 
141
 
142
- def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
143
- """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
 
 
 
 
 
144
 
145
  Args:
146
  model: MAE model to run.
@@ -158,12 +169,16 @@ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: floa
158
  _, pred, mask = model(x, mask_ratio)
159
 
160
  # Create mask and prediction images (un-patchify)
161
- mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
 
 
162
  pred_img = model.unpatchify(pred).detach().cpu()
163
 
164
  # Mix visible and predicted patches
165
  rec_img = input_data.clone()
166
- rec_img[mask_img == 1] = pred_img[mask_img == 1] # binary mask: 0 is keep, 1 is remove
 
 
167
 
168
  # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
169
  mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
@@ -171,8 +186,10 @@ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: floa
171
  return rec_img, mask_img
172
 
173
 
174
- def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
175
- """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
 
 
176
 
177
  Args:
178
  input_img: input torch.Tensor with shape (C, T, H, W).
@@ -186,30 +203,39 @@ def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir,
186
  """
187
 
188
  for t in range(input_img.shape[1]):
189
- rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
190
- new_img=rec_img[:, t, :, :],
191
- channels=channels, data_mean=mean,
192
- data_std=std)
 
193
 
194
  rgb_mask = mask_img[channels, t, :, :] * rgb_orig
195
 
196
  # Saving images
197
 
198
- save_geotiff(image=_convert_np_uint8(rgb_orig),
199
- output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
200
- meta=meta_data[t])
 
 
201
 
202
- save_geotiff(image=_convert_np_uint8(rgb_pred),
203
- output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
204
- meta=meta_data[t])
 
 
205
 
206
- save_geotiff(image=_convert_np_uint8(rgb_mask),
207
- output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
208
- meta=meta_data[t])
 
 
209
 
210
 
211
  def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
212
- """ Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
213
 
214
  Args:
215
  rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
@@ -224,7 +250,6 @@ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
224
  std = torch.tensor(np.asarray(std)[:, None, None])
225
 
226
  for t in range(rec_img.shape[1]):
227
-
228
  # Back to original data range
229
  rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
230
 
@@ -232,78 +257,98 @@ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
232
 
233
  # Saving images
234
 
235
- save_geotiff(image=rec_img_t,
236
- output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
237
- meta=meta_data[t])
238
-
239
- save_geotiff(image=mask_img_t,
240
- output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
241
- meta=meta_data[t])
242
-
243
-
244
- def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir: str,
245
- mask_ratio: float, rgb_outputs: bool):
246
-
 
247
  os.makedirs(output_dir, exist_ok=True)
248
 
249
  # Get parameters --------
250
 
251
- with open(yaml_file_path, 'r') as f:
252
  params = yaml.safe_load(f)
253
 
254
  # data related
 
255
  num_frames = len(data_files)
256
- img_size = params['img_size']
257
- bands = params['bands']
258
- mean = params['data_mean']
259
- std = params['data_std']
260
 
261
  # model related
262
- depth = params['depth']
263
- patch_size = params['patch_size']
264
- embed_dim = params['embed_dim']
265
- num_heads = params['num_heads']
266
- tubelet_size = params['tubelet_size']
267
- decoder_embed_dim = params['decoder_embed_dim']
268
- decoder_num_heads = params['decoder_num_heads']
269
- decoder_depth = params['decoder_depth']
270
-
271
- batch_size = params['batch_size']
272
-
273
- mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
274
-
275
- print(f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n")
 
 
 
 
276
  if len(data_files) != 3:
277
- print("The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary")
 
 
278
 
279
  if torch.cuda.is_available():
280
- device = torch.device('cuda')
281
  else:
282
- device = torch.device('cpu')
283
 
284
  print(f"Using {device} device.\n")
285
 
286
  # Loading data ---------------------------------------------------------------------------------
287
 
288
- input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)
 
 
289
 
290
  # Create model and load checkpoint -------------------------------------------------------------
291
 
292
  model = MaskedAutoencoderViT(
293
- img_size=img_size,
294
- patch_size=patch_size,
295
- num_frames=num_frames,
296
- tubelet_size=tubelet_size,
297
- in_chans=len(bands),
298
- embed_dim=embed_dim,
299
- depth=depth,
300
- num_heads=num_heads,
301
- decoder_embed_dim=decoder_embed_dim,
302
- decoder_depth=decoder_depth,
303
- decoder_num_heads=decoder_num_heads,
304
- mlp_ratio=4.,
305
- norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
306
- norm_pix_loss=False)
 
307
 
308
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
309
  print(f"\n--> Model has {total_params:,} parameters.\n")
@@ -312,27 +357,31 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
312
 
313
  state_dict = torch.load(checkpoint, map_location=device)
314
  # discard fixed pos_embedding weight
315
- del state_dict['pos_embed']
316
- del state_dict['decoder_pos_embed']
317
  model.load_state_dict(state_dict, strict=False)
318
  print(f"Loaded checkpoint from {checkpoint}")
319
 
320
  # Running model --------------------------------------------------------------------------------
321
 
322
  model.eval()
323
- channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
324
 
325
  # Reflect pad if not divisible by img_size
326
  original_h, original_w = input_data.shape[-2:]
327
  pad_h = img_size - (original_h % img_size)
328
  pad_w = img_size - (original_w % img_size)
329
- input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
 
 
330
 
331
  # Build sliding window
332
- batch = torch.tensor(input_data, device='cpu')
333
  windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
334
  h1, w1 = windows.shape[3:5]
335
- windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)
 
 
336
 
337
  # Split into batches if number of windows > batch_size
338
  num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
@@ -350,10 +399,28 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
350
  mask_imgs = torch.concat(mask_imgs, dim=0)
351
 
352
  # Build images from patches
353
- rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
354
- h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
355
- mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
356
- h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
 
357
 
358
  # Cut padded images back to original size
359
  rec_imgs_full = rec_imgs[..., :original_h, :original_w]
@@ -363,37 +430,88 @@ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir
363
  # Build output images
364
  if rgb_outputs:
365
  for d in meta_data:
366
- d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
367
-
368
- save_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
369
- channels, mean, std, output_dir, meta_data)
 
370
  else:
371
  for d in meta_data:
372
- d.update(compress='lzw', nodata=0)
373
 
374
- save_imgs(rec_imgs_full[0, ...], mask_imgs_full[0, ...], mean, std, output_dir, meta_data)
 
375
 
376
  print("Done!")
377
 
378
 
379
  if __name__ == "__main__":
380
- parser = argparse.ArgumentParser('MAE run inference', add_help=False)
381
-
382
- parser.add_argument('--data_files', required=True, type=str, nargs='+',
383
- help='Path to the data files. Assumes multi-band files.')
384
- parser.add_argument('--yaml_file_path', type=str, required=True,
385
- help='Path to yaml file containing model training parameters.')
386
- parser.add_argument('--checkpoint', required=True, type=str,
387
- help='Path to a checkpoint file to load from.')
388
- parser.add_argument('--output_dir', required=True, type=str,
389
- help='Path to the directory where to save outputs.')
390
- parser.add_argument('--mask_ratio', default=None, type=float,
391
- help='Masking ratio (percentage of removed patches). '
392
- 'If None (default) use same value used for pretraining.')
393
- parser.add_argument('--rgb_outputs', action='store_true',
394
- help='If present, output files will only contain RGB channels. '
395
- 'Otherwise, all bands will be saved.')
 
396
  args = parser.parse_args()
397
 
398
  main(**vars(args))
399
-
 
1
  import argparse
2
  import functools
3
  import os
4
+ from typing import List, Union
5
 
6
  import numpy as np
7
  import rasterio
 
17
 
18
 
19
  def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
20
+ """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
21
  original range using *data_mean* and *data_std* and then lowest and highest percentiles are
22
  removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
23
 
 
65
 
66
 
67
  def read_geotiff(file_path: str):
68
+ """Read all bands from *file_path* and return image + meta info.
69
 
70
  Args:
71
  file_path: path to image file.
 
83
 
84
 
85
  def save_geotiff(image, output_path: str, meta: dict):
86
+ """Save multi-band image in Geotiff file.
87
 
88
  Args:
89
  image: np.ndarray with shape (bands, height, width)
 
99
 
100
 
101
  def _convert_np_uint8(float_image: torch.Tensor):
 
102
  image = float_image.numpy() * 255.0
103
  image = image.astype(dtype=np.uint8)
104
 
105
  return image
106
 
107
 
108
+ def load_example(
109
+ file_paths: List[str],
110
+ mean: List[float],
111
+ std: List[float],
112
+ indices: Union[list[int], None] = None,
113
+ ):
114
+ """Build an input example by loading images in *file_paths*.
115
 
116
  Args:
117
  file_paths: list of file paths.
 
130
  img, meta = read_geotiff(file)
131
 
132
  # Rescaling (don't normalize on nodata)
133
+ img = np.moveaxis(img, 0, -1) # channels last for rescaling
134
+ if indices is not None:
135
+ img = img[..., indices]
136
  img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
137
 
138
  imgs.append(img)
139
  metas.append(meta)
140
 
141
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
142
+ imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
143
  imgs = np.expand_dims(imgs, axis=0) # add batch dim
144
 
145
  return imgs, metas
146
 
147
 
148
+ def run_model(
149
+ model: torch.nn.Module,
150
+ input_data: torch.Tensor,
151
+ mask_ratio: float,
152
+ device: torch.device,
153
+ ):
154
+ """Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
155
 
156
  Args:
157
  model: MAE model to run.
 
169
  _, pred, mask = model(x, mask_ratio)
170
 
171
  # Create mask and prediction images (un-patchify)
172
+ mask_img = (
173
+ model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
174
+ )
175
  pred_img = model.unpatchify(pred).detach().cpu()
176
 
177
  # Mix visible and predicted patches
178
  rec_img = input_data.clone()
179
+ rec_img[mask_img == 1] = pred_img[
180
+ mask_img == 1
181
+ ] # binary mask: 0 is keep, 1 is remove
182
 
183
  # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
184
  mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
 
186
  return rec_img, mask_img
187
 
188
 
189
+ def save_rgb_imgs(
190
+ input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
191
+ ):
192
+ """Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
193
 
194
  Args:
195
  input_img: input torch.Tensor with shape (C, T, H, W).
 
203
  """
204
 
205
  for t in range(input_img.shape[1]):
206
+ rgb_orig, rgb_pred = process_channel_group(
207
+ orig_img=input_img[:, t, :, :],
208
+ new_img=rec_img[:, t, :, :],
209
+ channels=channels,
210
+ data_mean=mean,
211
+ data_std=std,
212
+ )
213
 
214
  rgb_mask = mask_img[channels, t, :, :] * rgb_orig
215
 
216
  # Saving images
217
 
218
+ save_geotiff(
219
+ image=_convert_np_uint8(rgb_orig),
220
+ output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
221
+ meta=meta_data[t],
222
+ )
223
 
224
+ save_geotiff(
225
+ image=_convert_np_uint8(rgb_pred),
226
+ output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
227
+ meta=meta_data[t],
228
+ )
229
 
230
+ save_geotiff(
231
+ image=_convert_np_uint8(rgb_mask),
232
+ output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
233
+ meta=meta_data[t],
234
+ )
235
 
236
 
237
  def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
238
+ """Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
239
 
240
  Args:
241
  rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
 
250
  std = torch.tensor(np.asarray(std)[:, None, None])
251
 
252
  for t in range(rec_img.shape[1]):
 
253
  # Back to original data range
254
  rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
255
 
 
257
 
258
  # Saving images
259
 
260
+ save_geotiff(
261
+ image=rec_img_t,
262
+ output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
263
+ meta=meta_data[t],
264
+ )
265
+
266
+ save_geotiff(
267
+ image=mask_img_t,
268
+ output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
269
+ meta=meta_data[t],
270
+ )
271
+
272
+
273
+ def main(
274
+ data_files: List[str],
275
+ yaml_file_path: str,
276
+ checkpoint: str,
277
+ output_dir: str,
278
+ rgb_outputs: bool,
279
+ img_size: int,
280
+ mask_ratio: float = None,
281
+ input_indices: list[int] = None,
282
+ ):
283
  os.makedirs(output_dir, exist_ok=True)
284
 
285
  # Get parameters --------
286
 
287
+ with open(yaml_file_path, "r") as f:
288
  params = yaml.safe_load(f)
289
 
290
  # data related
291
+ train_params = params["train_params"]
292
  num_frames = len(data_files)
293
+ bands = train_params["bands"]
294
+ mean = train_params["data_mean"]
295
+ std = train_params["data_std"]
 
296
 
297
  # model related
298
+ model_params = params["model_args"]
299
+ img_size = model_params["img_size"] if img_size is None else img_size
300
+ depth = model_params["depth"]
301
+ patch_size = model_params["patch_size"]
302
+ embed_dim = model_params["embed_dim"]
303
+ num_heads = model_params["num_heads"]
304
+ tubelet_size = model_params["tubelet_size"]
305
+ decoder_embed_dim = model_params["decoder_embed_dim"]
306
+ decoder_num_heads = model_params["decoder_num_heads"]
307
+ decoder_depth = model_params["decoder_depth"]
308
+
309
+ batch_size = 1
310
+
311
+ mask_ratio = train_params["mask_ratio"] if mask_ratio is None else mask_ratio
312
+
313
+ print(
314
+ f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
315
+ )
316
  if len(data_files) != 3:
317
+ print(
318
+ "The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary"
319
+ )
320
 
321
  if torch.cuda.is_available():
322
+ device = torch.device("cuda")
323
  else:
324
+ device = torch.device("cpu")
325
 
326
  print(f"Using {device} device.\n")
327
 
328
  # Loading data ---------------------------------------------------------------------------------
329
 
330
+ input_data, meta_data = load_example(
331
+ file_paths=data_files, indices=input_indices, mean=mean, std=std
332
+ )
333
 
334
  # Create model and load checkpoint -------------------------------------------------------------
335
 
336
  model = MaskedAutoencoderViT(
337
+ img_size=img_size,
338
+ patch_size=patch_size,
339
+ num_frames=num_frames,
340
+ tubelet_size=tubelet_size,
341
+ in_chans=len(bands),
342
+ embed_dim=embed_dim,
343
+ depth=depth,
344
+ num_heads=num_heads,
345
+ decoder_embed_dim=decoder_embed_dim,
346
+ decoder_depth=decoder_depth,
347
+ decoder_num_heads=decoder_num_heads,
348
+ mlp_ratio=4.0,
349
+ norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
350
+ norm_pix_loss=False,
351
+ )
352
 
353
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
354
  print(f"\n--> Model has {total_params:,} parameters.\n")
 
357
 
358
  state_dict = torch.load(checkpoint, map_location=device)
359
  # discard fixed pos_embedding weight
360
+ del state_dict["pos_embed"]
361
+ del state_dict["decoder_pos_embed"]
362
  model.load_state_dict(state_dict, strict=False)
363
  print(f"Loaded checkpoint from {checkpoint}")
364
 
365
  # Running model --------------------------------------------------------------------------------
366
 
367
  model.eval()
368
+ channels = [bands.index(b) for b in ["B04", "B03", "B02"]] # BGR -> RGB
369
 
370
  # Reflect pad if not divisible by img_size
371
  original_h, original_w = input_data.shape[-2:]
372
  pad_h = img_size - (original_h % img_size)
373
  pad_w = img_size - (original_w % img_size)
374
+ input_data = np.pad(
375
+ input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
376
+ )
377
 
378
  # Build sliding window
379
+ batch = torch.tensor(input_data, device="cpu")
380
  windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
381
  h1, w1 = windows.shape[3:5]
382
+ windows = rearrange(
383
+ windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
384
+ )
385
 
386
  # Split into batches if number of windows > batch_size
387
  num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
 
399
  mask_imgs = torch.concat(mask_imgs, dim=0)
400
 
401
  # Build images from patches
402
+ rec_imgs = rearrange(
403
+ rec_imgs,
404
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
405
+ h=img_size,
406
+ w=img_size,
407
+ b=1,
408
+ c=len(bands),
409
+ t=num_frames,
410
+ h1=h1,
411
+ w1=w1,
412
+ )
413
+ mask_imgs = rearrange(
414
+ mask_imgs,
415
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
416
+ h=img_size,
417
+ w=img_size,
418
+ b=1,
419
+ c=len(bands),
420
+ t=num_frames,
421
+ h1=h1,
422
+ w1=w1,
423
+ )
424
 
425
  # Cut padded images back to original size
426
  rec_imgs_full = rec_imgs[..., :original_h, :original_w]
 
430
  # Build output images
431
  if rgb_outputs:
432
  for d in meta_data:
433
+ d.update(count=3, dtype="uint8", compress="lzw", nodata=0)
434
+
435
+ save_rgb_imgs(
436
+ batch_full[0, ...],
437
+ rec_imgs_full[0, ...],
438
+ mask_imgs_full[0, ...],
439
+ channels,
440
+ mean,
441
+ std,
442
+ output_dir,
443
+ meta_data,
444
+ )
445
  else:
446
  for d in meta_data:
447
+ d.update(compress="lzw", nodata=0)
448
 
449
+ save_imgs(
450
+ rec_imgs_full[0, ...],
451
+ mask_imgs_full[0, ...],
452
+ mean,
453
+ std,
454
+ output_dir,
455
+ meta_data,
456
+ )
457
 
458
  print("Done!")
459
 
460
 
461
  if __name__ == "__main__":
462
+ parser = argparse.ArgumentParser("MAE run inference", add_help=False)
463
+
464
+ parser.add_argument(
465
+ "--data_files",
466
+ required=True,
467
+ type=str,
468
+ nargs="+",
469
+ help="Path to the data files. Assumes multi-band files.",
470
+ )
471
+ parser.add_argument(
472
+ "--yaml_file_path",
473
+ type=str,
474
+ required=True,
475
+ help="Path to yaml file containing model training parameters.",
476
+ )
477
+ parser.add_argument(
478
+ "--checkpoint",
479
+ required=True,
480
+ type=str,
481
+ help="Path to a checkpoint file to load from.",
482
+ )
483
+ parser.add_argument(
484
+ "--output_dir",
485
+ required=True,
486
+ type=str,
487
+ help="Path to the directory where to save outputs.",
488
+ )
489
+ parser.add_argument(
490
+ "--mask_ratio",
491
+ default=None,
492
+ type=float,
493
+ help="Masking ratio (percentage of removed patches). "
494
+ "If None (default) use same value used for pretraining.",
495
+ )
496
+ parser.add_argument(
497
+ "--img_size",
498
+ default=224,
499
+ type=int,
500
+ help="Image size to be used with model. Defaults to 224",
501
+ )
502
+ parser.add_argument(
503
+ "--input_indices",
504
+ default=None,
505
+ type=int,
506
+ nargs="+",
507
+ help="0-based indices of channels to be selected from the input. By default takes all.",
508
+ )
509
+ parser.add_argument(
510
+ "--rgb_outputs",
511
+ action="store_true",
512
+ help="If present, output files will only contain RGB channels. "
513
+ "Otherwise, all bands will be saved.",
514
+ )
515
  args = parser.parse_args()
516
 
517
  main(**vars(args))
 
README.md CHANGED
@@ -36,9 +36,11 @@ The model follows the [original MAE repo](https://github.com/facebookresearch/ma
36
  There is an inference script (`Prithvi_run_inference.py`) that lets you run image reconstruction on a set of HLS images assumed to be from the same location at different time steps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
37
 
38
  ```
39
- python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --mask_ratio 0.5
40
  ```
41
 
 
 
42
  ### Finetuning examples
43
  Examples of finetuning the model for image segmentation using the mmsegmentation library are available through Hugging Face (e.g. [burn scars segmentation](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar), [flood mapping](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11), and [multi temporal crop classification](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification)), with the code used for the experiments available on [github](https://github.com/NASA-IMPACT/hls-foundation-os/tree/main/fine-tuning-examples). This also contains instructions to finetune the model for flood detection on the popular open access [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11).
44
 
 
36
  There is an inference script (`Prithvi_run_inference.py`) that lets you run image reconstruction on a set of HLS images assumed to be from the same location at different time steps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-demo).
37
 
38
  ```
39
+ python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --input_indices <space separated 0-based indices of channels to select from input> --mask_ratio 0.5 --img_size <length of one side of square input shape>
40
  ```
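For instance, a hypothetical invocation that keeps the default 224×224 window size and selects the six required bands by position (the indices and paths below are illustrative placeholders, not values prescribed by the model) might look like:

```
python Prithvi_run_inference.py --data_files t1.tif t2.tif t3.tif --yaml_file_path /path/to/yaml/Prithvi_100.yaml --checkpoint /path/to/checkpoint/Prithvi_100.pth --output_dir /path/to/out/dir/ --input_indices 0 1 2 3 4 5 --mask_ratio 0.5 --img_size 224
```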
41
 
42
+ This demo is a starting point that can be used to generalize to different input shapes / types.
43
+
44
  ### Finetuning examples
45
  Examples of finetuning the model for image segmentation using the mmsegmentation library are available through Hugging Face (e.g. [burn scars segmentation](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar), [flood mapping](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11), and [multi temporal crop classification](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification)), with the code used for the experiments available on [github](https://github.com/NASA-IMPACT/hls-foundation-os/tree/main/fine-tuning-examples). This also contains instructions to finetune the model for flood detection on the popular open access [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11).
46