blumenstiel commited on
Commit
d668197
Β·
1 Parent(s): 5089ae8

Update app

Browse files
Files changed (10) hide show
  1. app.py +117 -316
  2. HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
  3. HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
  4. HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
  5. HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
  6. HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
  7. HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
  8. HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
  9. HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
  10. HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +0 -0
app.py CHANGED
@@ -1,215 +1,28 @@
1
- #### pull files from hub
2
- from huggingface_hub import hf_hub_download
3
- import os
4
- yaml_file_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename="Prithvi_100M_config.yaml", token=os.environ.get("token"))
5
- checkpoint=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi_100M.pt', token=os.environ.get("token"))
6
- model_def=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi.py', token=os.environ.get("token"))
7
- os.system(f'cp {model_def} .')
8
- #####
9
- import argparse
10
- import functools
11
- import os
12
- from typing import List
13
 
14
- import numpy as np
15
- import rasterio
16
  import torch
17
  import yaml
18
- from einops import rearrange
19
-
20
- from Prithvi import MaskedAutoencoderViT
21
  import gradio as gr
 
22
  from functools import partial
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
- NO_DATA = -9999
26
- NO_DATA_FLOAT = 0.0001
27
- PERCENTILES = (0.1, 99.9)
28
-
29
-
30
- def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
31
- """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
32
- original range using *data_mean* and *data_std* and then lowest and highest percentiles are
33
- removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
34
- Args:
35
- orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
36
- new_img: torch.Tensor representing image with shape = (bands, H, W).
37
- channels: list of indices representing RGB channels.
38
- data_mean: list of mean values for each band.
39
- data_std: list of std values for each band.
40
- Returns:
41
- torch.Tensor with shape (num_channels, height, width) for original image
42
- torch.Tensor with shape (num_channels, height, width) for the other image
43
- """
44
-
45
- stack_c = [], []
46
-
47
- for c in channels:
48
- orig_ch = orig_img[c, ...]
49
- valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
50
- valid_mask[orig_ch == NO_DATA_FLOAT] = False
51
-
52
- # Back to original data range
53
- orig_ch = (orig_ch * data_std[c]) + data_mean[c]
54
- new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
55
-
56
- # Rescale (enhancing contrast)
57
- min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
58
-
59
- orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
60
- new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
61
-
62
- # No data as zeros
63
- orig_ch[~valid_mask] = 0
64
- new_ch[~valid_mask] = 0
65
-
66
- stack_c[0].append(orig_ch)
67
- stack_c[1].append(new_ch)
68
-
69
- # Channels first
70
- stack_orig = torch.stack(stack_c[0], dim=0)
71
- stack_rec = torch.stack(stack_c[1], dim=0)
72
-
73
- return stack_orig, stack_rec
74
-
75
-
76
- def read_geotiff(file_path: str):
77
- """ Read all bands from *file_path* and returns image + meta info.
78
- Args:
79
- file_path: path to image file.
80
- Returns:
81
- np.ndarray with shape (bands, height, width)
82
- meta info dict
83
- """
84
-
85
- with rasterio.open(file_path) as src:
86
- img = src.read()
87
- meta = src.meta
88
-
89
- return img, meta
90
-
91
-
92
- def save_geotiff(image, output_path: str, meta: dict):
93
- """ Save multi-band image in Geotiff file.
94
- Args:
95
- image: np.ndarray with shape (bands, height, width)
96
- output_path: path where to save the image
97
- meta: dict with meta info.
98
- """
99
-
100
- with rasterio.open(output_path, "w", **meta) as dest:
101
- for i in range(image.shape[0]):
102
- dest.write(image[i, :, :], i + 1)
103
-
104
- return
105
-
106
-
107
- def _convert_np_uint8(float_image: torch.Tensor):
108
-
109
- image = float_image.numpy() * 255.0
110
- image = image.astype(dtype=np.uint8)
111
- image = image.transpose((1, 2, 0))
112
-
113
- return image
114
-
115
-
116
- def load_example(file_paths: List[str], mean: List[float], std: List[float]):
117
- """ Build an input example by loading images in *file_paths*.
118
- Args:
119
- file_paths: list of file paths .
120
- mean: list containing mean values for each band in the images in *file_paths*.
121
- std: list containing std values for each band in the images in *file_paths*.
122
- Returns:
123
- np.array containing created example
124
- list of meta info for each image in *file_paths*
125
- """
126
-
127
- imgs = []
128
- metas = []
129
-
130
- for file in file_paths:
131
- img, meta = read_geotiff(file)
132
- img = img[:6]*10000 if img[:6].mean() <= 2 else img[:6]
133
-
134
- # Rescaling (don't normalize on nodata)
135
- img = np.moveaxis(img, 0, -1) # channels last for rescaling
136
- img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
137
-
138
- imgs.append(img)
139
- metas.append(meta)
140
-
141
- imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
142
- imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
143
- imgs = np.expand_dims(imgs, axis=0) # add batch dim
144
-
145
- return imgs, metas
146
-
147
-
148
- def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
149
- """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
150
- Args:
151
- model: MAE model to run.
152
- input_data: torch.Tensor with shape (B, C, T, H, W).
153
- mask_ratio: mask ratio to use.
154
- device: device where model should run.
155
- Returns:
156
- 3 torch.Tensor with shape (B, C, T, H, W).
157
- """
158
-
159
- with torch.no_grad():
160
- x = input_data.to(device)
161
-
162
- _, pred, mask = model(x, mask_ratio)
163
-
164
- # Create mask and prediction images (un-patchify)
165
- mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
166
- pred_img = model.unpatchify(pred).detach().cpu()
167
-
168
- # Mix visible and predicted patches
169
- rec_img = input_data.clone()
170
- rec_img[mask_img == 1] = pred_img[mask_img == 1] # binary mask: 0 is keep, 1 is remove
171
-
172
- # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
173
- mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
174
-
175
- return rec_img, mask_img
176
-
177
-
178
- def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
179
- """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
180
- Args:
181
- input_img: input torch.Tensor with shape (C, T, H, W).
182
- rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
183
- mask_img: mask torch.Tensor with shape (C, T, H, W).
184
- channels: list of indices representing RGB channels.
185
- mean: list of mean values for each band.
186
- std: list of std values for each band.
187
- output_dir: directory where to save outputs.
188
- meta_data: list of dicts with geotiff meta info.
189
- """
190
-
191
- for t in range(input_img.shape[1]):
192
- rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
193
- new_img=rec_img[:, t, :, :],
194
- channels=channels, data_mean=mean,
195
- data_std=std)
196
-
197
- rgb_mask = mask_img[channels, t, :, :] * rgb_orig
198
-
199
- # Saving images
200
-
201
- save_geotiff(image=_convert_np_uint8(rgb_orig),
202
- output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
203
- meta=meta_data[t])
204
-
205
- save_geotiff(image=_convert_np_uint8(rgb_pred),
206
- output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
207
- meta=meta_data[t])
208
-
209
- save_geotiff(image=_convert_np_uint8(rgb_mask),
210
- output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
211
- meta=meta_data[t])
212
-
213
 
214
  def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
215
  """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
@@ -230,24 +43,31 @@ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
230
  for t in range(input_img.shape[1]):
231
  rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
232
  new_img=rec_img[:, t, :, :],
233
- channels=channels, data_mean=mean,
234
- data_std=std)
 
235
 
236
  rgb_mask = mask_img[channels, t, :, :] * rgb_orig
237
 
238
  # extract images
239
- rgb_orig_list.append(_convert_np_uint8(rgb_orig))
240
- rgb_mask_list.append(_convert_np_uint8(rgb_mask))
241
- rgb_pred_list.append(_convert_np_uint8(rgb_pred))
242
-
 
 
 
 
 
 
 
 
243
  outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
244
 
245
  return outputs
246
 
247
 
248
- def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str, checkpoint: str):
249
-
250
-
251
  try:
252
  data_files = [x.name for x in data_files]
253
  print('Path extracted from example')
@@ -257,24 +77,18 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
257
  # Get parameters --------
258
  print('This is the printout', data_files)
259
 
260
- with open(yaml_file_path, 'r') as f:
261
- params = yaml.safe_load(f)
262
-
263
- model_params = params["model_args"]
264
- # data related
265
- train_params = params["train_params"]
266
- num_frames = model_params['num_frames']
267
- img_size = model_params['img_size']
268
- bands = train_params['bands']
269
- mean = train_params['data_mean']
270
- std = train_params['data_std']
271
 
272
  batch_size = 8
 
 
 
 
 
 
273
 
274
- mask_ratio = train_params['mask_ratio'] if mask_ratio is None else mask_ratio
275
-
276
- # We must have *num_frames* files to build one example!
277
- assert len(data_files) == num_frames, "File list must be equal to expected number of frames."
278
 
279
  if torch.cuda.is_available():
280
  device = torch.device('cuda')
@@ -289,16 +103,23 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
289
 
290
  # Create model and load checkpoint -------------------------------------------------------------
291
 
292
- model = MaskedAutoencoderViT(
293
- **model_params)
 
 
 
294
 
295
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
296
  print(f"\n--> Model has {total_params:,} parameters.\n")
297
 
298
  model.to(device)
299
 
300
- state_dict = torch.load(checkpoint, map_location=device)
301
- model.load_state_dict(state_dict)
 
 
 
 
302
  print(f"Loaded checkpoint from {checkpoint}")
303
 
304
  # Running model --------------------------------------------------------------------------------
@@ -348,37 +169,35 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
348
  for d in meta_data:
349
  d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
350
 
351
- # save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
352
- # channels, mean, std, output_dir, meta_data)
353
-
354
  outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
355
  channels, mean, std)
356
 
357
-
358
  print("Done!")
359
 
360
  return outputs
361
 
362
 
 
363
 
 
364
 
365
- func = partial(predict_on_images, yaml_file_path=yaml_file_path,checkpoint=checkpoint)
 
 
366
 
367
- def preprocess_example(example_list):
368
- print('######## preprocessing here ##########')
369
- example_list = [os.path.join(os.path.abspath(''), x) for x in example_list]
370
-
371
- return example_list
372
-
373
 
374
- with gr.Blocks() as demo:
375
-
376
- gr.Markdown(value='# Prithvi image reconstruction demo')
377
- gr.Markdown(value='''Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data. Particularly, the model adopts a self-supervised encoder developed with a ViT architecture and Masked AutoEncoder learning strategy, with a MSE as a loss function. The model includes spatial attention across multiple patchies and also temporal attention for each patch. More info about the model and its weights are available [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M).\n
378
- This demo showcases the image reconstracting over three timestamps, with the user providing a set of three HLS images and the model randomly masking out some proportion of the images and then reconstructing them based on the not masked portion of the images.\n
379
- The user needs to provide three HLS geotiff images, including the following channels in reflectance units: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
380
 
381
- Check out our newest model: [Prithvi-EO-2.0-Demo](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Demo).
 
382
  ''')
383
  with gr.Row():
384
  with gr.Column():
@@ -386,73 +205,55 @@ Check out our newest model: [Prithvi-EO-2.0-Demo](https://huggingface.co/spaces/
386
  # inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
387
  btn = gr.Button("Submit")
388
  with gr.Row():
389
- gr.Markdown(value='## Original images')
390
- with gr.Row():
391
- gr.Markdown(value='T1')
392
- gr.Markdown(value='T2')
393
- gr.Markdown(value='T3')
394
- with gr.Row():
395
- out1_orig_t1=gr.Image(image_mode='RGB')
396
- out2_orig_t2 = gr.Image(image_mode='RGB')
397
- out3_orig_t3 = gr.Image(image_mode='RGB')
398
- with gr.Row():
399
  gr.Markdown(value='## Masked images')
400
- with gr.Row():
401
- gr.Markdown(value='T1')
402
- gr.Markdown(value='T2')
403
- gr.Markdown(value='T3')
404
- with gr.Row():
405
- out4_masked_t1=gr.Image(image_mode='RGB')
406
- out5_masked_t2 = gr.Image(image_mode='RGB')
407
- out6_masked_t3 = gr.Image(image_mode='RGB')
408
- with gr.Row():
409
- gr.Markdown(value='## Reonstructed images')
410
- with gr.Row():
411
- gr.Markdown(value='T1')
412
- gr.Markdown(value='T2')
413
- gr.Markdown(value='T3')
414
- with gr.Row():
415
- out7_pred_t1=gr.Image(image_mode='RGB')
416
- out8_pred_t2 = gr.Image(image_mode='RGB')
417
- out9_pred_t3 = gr.Image(image_mode='RGB')
418
-
419
-
420
- btn.click(fn=func,
421
- # inputs=[inp_files, inp_slider],
422
  inputs=inp_files,
423
- outputs=[out1_orig_t1,
424
- out2_orig_t2,
425
- out3_orig_t3,
426
- out4_masked_t1,
427
- out5_masked_t2,
428
- out6_masked_t3,
429
- out7_pred_t1,
430
- out8_pred_t2,
431
- out9_pred_t3])
432
  with gr.Row():
433
- gr.Examples(examples=[[[os.path.join(os.path.dirname(__file__), "HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
434
- os.path.join(os.path.dirname(__file__), "HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
435
- os.path.join(os.path.dirname(__file__), "HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]],
436
- [[os.path.join(os.path.dirname(__file__), "HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
437
- os.path.join(os.path.dirname(__file__), "HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
438
- os.path.join(os.path.dirname(__file__), "HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]],
439
- [[os.path.join(os.path.dirname(__file__), "HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
440
- os.path.join(os.path.dirname(__file__), "HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
441
- os.path.join(os.path.dirname(__file__), "HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")]]],
442
- inputs=inp_files,
443
- outputs=[out1_orig_t1,
444
- out2_orig_t2,
445
- out3_orig_t3,
446
- out4_masked_t1,
447
- out5_masked_t2,
448
- out6_masked_t3,
449
- out7_pred_t1,
450
- out8_pred_t2,
451
- out9_pred_t3],
452
- # preprocess=preprocess_example,
453
- fn=func,
454
- cache_examples=True
455
  )
456
-
457
 
458
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
 
2
+ import os
 
3
  import torch
4
  import yaml
5
+ import numpy as np
 
 
6
  import gradio as gr
7
+ from einops import rearrange
8
  from functools import partial
9
+ from huggingface_hub import hf_hub_download
10
 
11
+ # pull files from hub
12
+ token = os.environ.get("HF_TOKEN", None)
13
+ config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
14
+ filename="config.json", token=token)
15
+ checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
16
+ filename='Prithvi_EO_V1_100M.pt', token=token)
17
+ model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
18
+ filename='prithvi_mae.py', token=token)
19
+ model_inference = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
20
+ filename='inference.py', token=token)
21
+ os.system(f'cp {model_def} .')
22
+ os.system(f'cp {model_inference} .')
23
 
24
+ from prithvi_mae import PrithviMAE
25
+ from inference import process_channel_group, _convert_np_uint8, load_example, run_model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
28
  """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
 
43
  for t in range(input_img.shape[1]):
44
  rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
45
  new_img=rec_img[:, t, :, :],
46
+ channels=channels,
47
+ mean=mean,
48
+ std=std)
49
 
50
  rgb_mask = mask_img[channels, t, :, :] * rgb_orig
51
 
52
  # extract images
53
+ rgb_orig_list.append(_convert_np_uint8(rgb_orig).transpose(1, 2, 0))
54
+ rgb_mask_list.append(_convert_np_uint8(rgb_mask).transpose(1, 2, 0))
55
+ rgb_pred_list.append(_convert_np_uint8(rgb_pred).transpose(1, 2, 0))
56
+
57
+ # Add white dummy image values for missing timestamps
58
+ dummy = np.ones((20, 20), dtype=np.uint8) * 255
59
+ num_dummies = 3 - len(rgb_orig_list)
60
+ if num_dummies:
61
+ rgb_orig_list.extend([dummy] * num_dummies)
62
+ rgb_mask_list.extend([dummy] * num_dummies)
63
+ rgb_pred_list.extend([dummy] * num_dummies)
64
+
65
  outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
66
 
67
  return outputs
68
 
69
 
70
+ def predict_on_images(data_files: list, config_path: str, checkpoint: str, mask_ratio: float = None):
 
 
71
  try:
72
  data_files = [x.name for x in data_files]
73
  print('Path extracted from example')
 
77
  # Get parameters --------
78
  print('This is the printout', data_files)
79
 
80
+ with open(config_path, 'r') as f:
81
+ config = yaml.safe_load(f)['pretrained_cfg']
 
 
 
 
 
 
 
 
 
82
 
83
  batch_size = 8
84
+ bands = config['bands']
85
+ num_frames = len(data_files)
86
+ mean = config['mean']
87
+ std = config['std']
88
+ img_size = config['img_size']
89
+ mask_ratio = mask_ratio or config['mask_ratio']
90
 
91
+ assert num_frames <= 3, "Demo only supports up to three timestamps"
 
 
 
92
 
93
  if torch.cuda.is_available():
94
  device = torch.device('cuda')
 
103
 
104
  # Create model and load checkpoint -------------------------------------------------------------
105
 
106
+ config.update(
107
+ num_frames=num_frames,
108
+ )
109
+
110
+ model = PrithviMAE(**config)
111
 
112
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
113
  print(f"\n--> Model has {total_params:,} parameters.\n")
114
 
115
  model.to(device)
116
 
117
+ state_dict = torch.load(checkpoint, map_location=device, weights_only=False)
118
+ # discard fixed pos_embedding weight
119
+ for k in list(state_dict.keys()):
120
+ if 'pos_embed' in k:
121
+ del state_dict[k]
122
+ model.load_state_dict(state_dict, strict=False)
123
  print(f"Loaded checkpoint from {checkpoint}")
124
 
125
  # Running model --------------------------------------------------------------------------------
 
169
  for d in meta_data:
170
  d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
171
 
 
 
 
172
  outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
173
  channels, mean, std)
174
 
 
175
  print("Done!")
176
 
177
  return outputs
178
 
179
 
180
+ run_inference = partial(predict_on_images, config_path=config_path,checkpoint=checkpoint)
181
 
182
+ with gr.Blocks() as demo:
183
 
184
+ gr.Markdown(value='# Prithvi-EO-1.0 image reconstruction demo')
185
+ gr.Markdown(value='''
186
+ Check out our newest model: [Prithvi-EO-2.0-Demo](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Demo).
187
 
188
+ Prithvi is a first-of-its-kind temporal Vision transformer pretrained by the IBM and NASA team on continental US Harmonised Landsat Sentinel 2 (HLS) data.
189
+ Particularly, the model adopts a self-supervised encoder developed with a ViT architecture and Masked AutoEncoder learning strategy, with a MSE as a loss function.
190
+ The model includes spatial attention across multiple patchies and also temporal attention for each patch.
191
+ More info about the model and its weights are available [here](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M).\n
 
 
192
 
193
+ This demo showcases the image reconstruction over one to three timestamps.
194
+ The model randomly masks out some proportion of the images and reconstructs them based on the not masked portion of the images.
195
+ The reconstructed images are merged with the visible unmasked patches.
196
+ We recommend submitting images of size 224 to ~1000 pixels for faster processing time.
197
+ Images bigger than 224x224 are processed using a sliding window approach which can lead to artefacts between patches.\n
 
198
 
199
+ The user needs to provide the HLS geotiff images, including the following channels in reflectance units: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
200
+ Some example images are provided at the end of this page.
201
  ''')
202
  with gr.Row():
203
  with gr.Column():
 
205
  # inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
206
  btn = gr.Button("Submit")
207
  with gr.Row():
208
+ gr.Markdown(value='## Input time series')
 
 
 
 
 
 
 
 
 
209
  gr.Markdown(value='## Masked images')
210
+ gr.Markdown(value='## Reconstructed images*')
211
+
212
+ original = []
213
+ masked = []
214
+ predicted = []
215
+ timestamps = []
216
+ for t in range(3):
217
+ timestamps.append(gr.Column(visible=t == 0))
218
+ with timestamps[t]:
219
+ #with gr.Row():
220
+ # gr.Markdown(value=f"Timestamp {t+1}")
221
+ with gr.Row():
222
+ original.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
223
+ masked.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
224
+ predicted.append(gr.Image(image_mode='RGB', show_label=False, show_fullscreen_button=False))
225
+
226
+ gr.Markdown(value='\* The reconstructed images include the ground truth unmasked patches.')
227
+
228
+ btn.click(fn=run_inference,
 
 
 
229
  inputs=inp_files,
230
+ outputs=original + masked + predicted)
231
+
 
 
 
 
 
 
 
232
  with gr.Row():
233
+ gr.Examples(examples=[[[
234
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
235
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
236
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
237
+ ]],[[
238
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
239
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
240
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
241
+ ]],[[
242
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
243
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"),
244
+ os.path.join(os.path.dirname(__file__), "examples/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif")
245
+ ]]],
246
+ inputs=inp_files,
247
+ outputs=original + masked + predicted,
248
+ fn=run_inference,
249
+ cache_examples=True
 
 
 
 
 
250
  )
 
251
 
252
+ def update_visibility(files):
253
+ timestamps = [gr.Column(visible=t < len(files)) for t in range(3)]
254
+
255
+ return timestamps
256
+
257
+ inp_files.change(update_visibility, inp_files, timestamps)
258
+
259
+ demo.launch() # share=True, ssr_mode=False
HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes
HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes
HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes
HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T17RMP.2018004T155509.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes
HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T17RMP.2018036T155452.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes
HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T17RMP.2018068T155438.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes
HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T18TVL.2018029T154533.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes
HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T18TVL.2018141T154435.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes
HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif β†’ examples/HLS.L30.T18TVL.2018189T154446.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif RENAMED
File without changes