Paolo-Fraccaro committed on
Commit
f877487
1 Parent(s): f96021c
Files changed (3)
  1. Dockerfile +38 -0
  2. app.py +426 -0
  3. requirements.txt +6 -0
Dockerfile ADDED
@@ -0,0 +1,38 @@
+ FROM python:3.9
+
+
+ RUN apt-get update && apt-get install --no-install-recommends -y \
+     build-essential \
+     python3.9 \
+     python3-pip \
+     git \
+     && apt-get clean && rm -rf /var/lib/apt/lists/*
+
+ WORKDIR /code
+
+ COPY ./requirements.txt /code/requirements.txt
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+ # Switch to the "user" user
+ USER user
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH \
+     PYTHONPATH=$HOME/app \
+     PYTHONUNBUFFERED=1 \
+     GRADIO_ALLOW_FLAGGING=never \
+     GRADIO_NUM_PORTS=1 \
+     GRADIO_SERVER_NAME=0.0.0.0 \
+     GRADIO_THEME=huggingface \
+     SYSTEM=spaces
+
+ RUN pip3 install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ # Copy the current directory contents into the container at $HOME/app, setting the owner to the user
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "app.py"]
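The image binds Gradio to 0.0.0.0 (GRADIO_SERVER_NAME) and starts app.py as its CMD, so the container can be exercised like any local web service. A minimal smoke test, assuming the container's port has been mapped to localhost:7860 (Gradio's default port; the port mapping and the check_space helper are illustrative, not part of this commit):

import urllib.request

def check_space(url: str = "http://localhost:7860") -> bool:
    """Return True if the Gradio app answers with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            return resp.status == 200
    except OSError:
        return False

if __name__ == "__main__":
    print("Gradio app reachable:", check_space())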
app.py ADDED
@@ -0,0 +1,426 @@
+ import argparse
+ import functools
+ import os
+ from typing import List
+
+ import numpy as np
+ import rasterio
+ import torch
+ import yaml
+ from einops import rearrange
+
+ from models_mae import MaskedAutoencoderViT
+ import gradio as gr
+ from functools import partial
+
+
+ NO_DATA = -9999
+ NO_DATA_FLOAT = 0.0001
+ PERCENTILES = (0.1, 99.9)
+
+
+ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
+     """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
+     original range using *data_mean* and *data_std*, and then the lowest and highest percentiles are
+     removed to enhance contrast. Data is rescaled to the (0, 1) range and stacked channels-first.
+     Args:
+         orig_img: torch.Tensor representing the original (reference) image with shape = (bands, H, W).
+         new_img: torch.Tensor representing an image with shape = (bands, H, W).
+         channels: list of indices representing RGB channels.
+         data_mean: list of mean values for each band.
+         data_std: list of std values for each band.
+     Returns:
+         torch.Tensor with shape (num_channels, height, width) for the original image
+         torch.Tensor with shape (num_channels, height, width) for the other image
+     """
+
+     stack_c = [], []
+
+     for c in channels:
+         orig_ch = orig_img[c, ...]
+         valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
+         valid_mask[orig_ch == NO_DATA_FLOAT] = False
+
+         # Back to original data range
+         orig_ch = (orig_ch * data_std[c]) + data_mean[c]
+         new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
+
+         # Rescale (enhancing contrast)
+         min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
+
+         orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
+         new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
+
+         # No data as zeros
+         orig_ch[~valid_mask] = 0
+         new_ch[~valid_mask] = 0
+
+         stack_c[0].append(orig_ch)
+         stack_c[1].append(new_ch)
+
+     # Channels first
+     stack_orig = torch.stack(stack_c[0], dim=0)
+     stack_rec = torch.stack(stack_c[1], dim=0)
+
+     return stack_orig, stack_rec
+
+
+ def read_geotiff(file_path: str):
+     """Read all bands from *file_path* and return the image + meta info.
+     Args:
+         file_path: path to image file.
+     Returns:
+         np.ndarray with shape (bands, height, width)
+         meta info dict
+     """
+
+     with rasterio.open(file_path) as src:
+         img = src.read()
+         meta = src.meta
+
+     return img, meta
+
+
+ def save_geotiff(image, output_path: str, meta: dict):
+     """Save a multi-band image as a GeoTIFF file.
+     Args:
+         image: np.ndarray with shape (bands, height, width)
+         output_path: path where to save the image
+         meta: dict with meta info.
+     """
+
+     with rasterio.open(output_path, "w", **meta) as dest:
+         for i in range(image.shape[0]):
+             dest.write(image[i, :, :], i + 1)
+
+     return
+
+
+ def _convert_np_uint8(float_image: torch.Tensor):
+
+     image = float_image.numpy() * 255.0
+     image = image.astype(dtype=np.uint8)
+     image = image.transpose((1, 2, 0))
+
+     return image
+
+
+ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
+     """Build an input example by loading the images in *file_paths*.
+     Args:
+         file_paths: list of file paths.
+         mean: list containing mean values for each band in the images in *file_paths*.
+         std: list containing std values for each band in the images in *file_paths*.
+     Returns:
+         np.array containing the created example
+         list of meta info for each image in *file_paths*
+     """
+
+     imgs = []
+     metas = []
+
+     for file in file_paths:
+         img, meta = read_geotiff(file)
+         img = img[:6] * 10000
+
+         # Rescaling (don't normalize on nodata)
+         img = np.moveaxis(img, 0, -1)  # channels last for rescaling
+         img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
+
+         imgs.append(img)
+         metas.append(meta)
+
+     imgs = np.stack(imgs, axis=0)  # num_frames, img_size, img_size, C
+     imgs = np.moveaxis(imgs, -1, 0).astype('float32')  # C, num_frames, img_size, img_size
+     imgs = np.expand_dims(imgs, axis=0)  # add batch dim
+
+     return imgs, metas
+
+
+ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
+     """Run *model* with *input_data* and create images from the output tokens (mask, reconstructed + visible).
+     Args:
+         model: MAE model to run.
+         input_data: torch.Tensor with shape (B, C, T, H, W).
+         mask_ratio: mask ratio to use.
+         device: device where the model should run.
+     Returns:
+         2 torch.Tensors with shape (B, C, T, H, W): the reconstructed image and the mask image.
+     """
+
+     with torch.no_grad():
+         x = input_data.to(device)
+
+         _, pred, mask = model(x, mask_ratio)
+
+     # Create mask and prediction images (un-patchify)
+     mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
+     pred_img = model.unpatchify(pred).detach().cpu()
+
+     # Mix visible and predicted patches
+     rec_img = input_data.clone()
+     rec_img[mask_img == 1] = pred_img[mask_img == 1]  # binary mask: 0 is keep, 1 is remove
+
+     # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
+     mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
+
+     return rec_img, mask_img
+
+
+ def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
+     """Wrapper function to save GeoTIFF images (original, reconstructed, masked) per timestamp.
+     Args:
+         input_img: input torch.Tensor with shape (C, T, H, W).
+         rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
+         mask_img: mask torch.Tensor with shape (C, T, H, W).
+         channels: list of indices representing RGB channels.
+         mean: list of mean values for each band.
+         std: list of std values for each band.
+         output_dir: directory where to save outputs.
+         meta_data: list of dicts with geotiff meta info.
+     """
+
+     for t in range(input_img.shape[1]):
+         rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
+                                                    new_img=rec_img[:, t, :, :],
+                                                    channels=channels, data_mean=mean,
+                                                    data_std=std)
+
+         rgb_mask = mask_img[channels, t, :, :] * rgb_orig
+
+         # Saving images
+
+         save_geotiff(image=_convert_np_uint8(rgb_orig),
+                      output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
+                      meta=meta_data[t])
+
+         save_geotiff(image=_convert_np_uint8(rgb_pred),
+                      output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
+                      meta=meta_data[t])
+
+         save_geotiff(image=_convert_np_uint8(rgb_mask),
+                      output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
+                      meta=meta_data[t])
+
+
+ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
+     """Wrapper function to extract RGB images (original, masked, reconstructed) per timestamp.
+     Args:
+         input_img: input torch.Tensor with shape (C, T, H, W).
+         rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
+         mask_img: mask torch.Tensor with shape (C, T, H, W).
+         channels: list of indices representing RGB channels.
+         mean: list of mean values for each band.
+         std: list of std values for each band.
+     Returns:
+         list of np.ndarray RGB images: originals, then masked, then reconstructed.
+     """
+     rgb_orig_list = []
+     rgb_mask_list = []
+     rgb_pred_list = []
+
+     for t in range(input_img.shape[1]):
+         rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
+                                                    new_img=rec_img[:, t, :, :],
+                                                    channels=channels, data_mean=mean,
+                                                    data_std=std)
+
+         rgb_mask = mask_img[channels, t, :, :] * rgb_orig
+
+         # Extract images
+         rgb_orig_list.append(_convert_np_uint8(rgb_orig))
+         rgb_mask_list.append(_convert_np_uint8(rgb_mask))
+         rgb_pred_list.append(_convert_np_uint8(rgb_pred))
+
+     outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
+
+     return outputs
+
+
+ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, mask_ratio: float = None):
+
+     # os.makedirs(output_dir, exist_ok=True)
+
+     # Get parameters --------
+
+     with open(yaml_file_path, 'r') as f:
+         params = yaml.safe_load(f)
+
+     # data related
+     num_frames = params['num_frames']
+     img_size = params['img_size']
+     bands = params['bands']
+     mean = params['data_mean']
+     std = params['data_std']
+
+     # model related
+     depth = params['depth']
+     patch_size = params['patch_size']
+     embed_dim = params['embed_dim']
+     num_heads = params['num_heads']
+     tubelet_size = params['tubelet_size']
+     decoder_embed_dim = params['decoder_embed_dim']
+     decoder_num_heads = params['decoder_num_heads']
+     decoder_depth = params['decoder_depth']
+
+     batch_size = params['batch_size']
+
+     mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
+
+     # We must have *num_frames* files to build one example!
+     assert len(data_files) == num_frames, "File list must be equal to expected number of frames."
+
+     if torch.cuda.is_available():
+         device = torch.device('cuda')
+     else:
+         device = torch.device('cpu')
+
+     print(f"Using {device} device.\n")
+
+     # Loading data ---------------------------------------------------------------------------------
+
+     input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)
+
+     # Create model and load checkpoint -------------------------------------------------------------
+
+     model = MaskedAutoencoderViT(
+         img_size=img_size,
+         patch_size=patch_size,
+         num_frames=num_frames,
+         tubelet_size=tubelet_size,
+         in_chans=len(bands),
+         embed_dim=embed_dim,
+         depth=depth,
+         num_heads=num_heads,
+         decoder_embed_dim=decoder_embed_dim,
+         decoder_depth=decoder_depth,
+         decoder_num_heads=decoder_num_heads,
+         mlp_ratio=4.,
+         norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
+         norm_pix_loss=False)
+
+     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     print(f"\n--> model has {total_params / 1e6} Million params.\n")
+
+     model.to(device)
+
+     state_dict = torch.load(checkpoint, map_location=device)
+     model.load_state_dict(state_dict)
+     print(f"Loaded checkpoint from {checkpoint}")
+
+     # Running model --------------------------------------------------------------------------------
+
+     model.eval()
+     channels = [bands.index(b) for b in ['B04', 'B03', 'B02']]  # BGR -> RGB
+
+     # Build sliding window
+     batch = torch.tensor(input_data, device='cpu')
+     windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
+     h1, w1 = windows.shape[3:5]
+     windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)
+
+     # Split into batches if number of windows > batch_size
+     num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
+     windows = torch.tensor_split(windows, num_batches, dim=0)
+
+     # Run model
+     rec_imgs = []
+     mask_imgs = []
+     for x in windows:
+         rec_img, mask_img = run_model(model, x, mask_ratio, device)
+         rec_imgs.append(rec_img)
+         mask_imgs.append(mask_img)
+
+     rec_imgs = torch.concat(rec_imgs, dim=0)
+     mask_imgs = torch.concat(mask_imgs, dim=0)
+
+     # Build images from patches
+     rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
+                          h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
+     mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
+                           h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
+
+     # Mix original image with patches
+     h, w = rec_imgs.shape[-2:]
+     rec_imgs_full = batch.clone()
+     rec_imgs_full[..., :h, :w] = rec_imgs
+
+     mask_imgs_full = torch.ones_like(batch)
+     mask_imgs_full[..., :h, :w] = mask_imgs
+
+     # Build RGB images
+     for d in meta_data:
+         d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
+
+     # save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
+     #               channels, mean, std, output_dir, meta_data)
+
+     outputs = extract_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
+                                channels, mean, std)
+
+     print("Done!")
+
+     return outputs
+
+
+ from huggingface_hub import hf_hub_download
+
+ yaml_file_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename="Prithvi_100M_config.yaml")
+ checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi_100M.pt')
+
+ func = partial(predict_on_images, yaml_file_path=yaml_file_path, checkpoint=checkpoint)
+
+
+ with gr.Blocks() as demo:
+
+     with gr.Row():
+         with gr.Column():
+             inp_files = gr.Files(elem_id='files')
+             # inp_slider = gr.Slider(0, 100, value=50, label="Mask ratio", info="Choose ratio of masking between 0 and 100", elem_id='slider'),
+             btn = gr.Button("Submit")
+     with gr.Row():
+         gr.Markdown(value='Original images')
+     with gr.Row():
+         gr.Markdown(value='T1')
+         gr.Markdown(value='T2')
+         gr.Markdown(value='T3')
+     with gr.Row():
+         out1_orig_t1 = gr.Image(image_mode='RGB')
+         out2_orig_t2 = gr.Image(image_mode='RGB')
+         out3_orig_t3 = gr.Image(image_mode='RGB')
+     with gr.Row():
+         gr.Markdown(value='Masked images')
+     with gr.Row():
+         gr.Markdown(value='T1')
+         gr.Markdown(value='T2')
+         gr.Markdown(value='T3')
+     with gr.Row():
+         out4_masked_t1 = gr.Image(image_mode='RGB')
+         out5_masked_t2 = gr.Image(image_mode='RGB')
+         out6_masked_t3 = gr.Image(image_mode='RGB')
+     with gr.Row():
+         gr.Markdown(value='Reconstructed images')
+     with gr.Row():
+         gr.Markdown(value='T1')
+         gr.Markdown(value='T2')
+         gr.Markdown(value='T3')
+     with gr.Row():
+         out7_pred_t1 = gr.Image(image_mode='RGB')
+         out8_pred_t2 = gr.Image(image_mode='RGB')
+         out9_pred_t3 = gr.Image(image_mode='RGB')
+
+     btn.click(fn=func,
+               # inputs=[inp_files, inp_slider],
+               inputs=inp_files,
+               outputs=[out1_orig_t1,
+                        out2_orig_t2,
+                        out3_orig_t3,
+                        out4_masked_t1,
+                        out5_masked_t2,
+                        out6_masked_t3,
+                        out7_pred_t1,
+                        out8_pred_t2,
+                        out9_pred_t3])
+
+ demo.launch()
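predict_on_images tiles scenes larger than the model's input into non-overlapping img_size x img_size windows with torch.unfold, runs the MAE over the flattened window batch, and stitches the outputs back together with einops.rearrange. A self-contained sketch of that tiling round trip, using toy shapes in place of the real six-band input (the 448-pixel scene and 224-pixel window below are illustrative values, not taken from the commit):

import torch
from einops import rearrange

# Toy stand-in: 1 sample, 6 bands, 3 frames, 448x448 scene
batch = torch.randn(1, 6, 3, 448, 448)
img_size = 224  # window side, read from the model config in app.py

# Cut non-overlapping img_size x img_size windows along H and W
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
h1, w1 = windows.shape[3:5]  # number of windows per axis (2 x 2 here)

# Flatten the window grid into the batch dimension for the model
windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w')
assert windows.shape == (4, 6, 3, 224, 224)

# ... the MAE would run on `windows` here ...

# Stitch the windows back into the full scene
restored = rearrange(windows, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
                     b=1, h1=h1, w1=w1)
assert torch.equal(restored, batch)  # the round trip is lossless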
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ torchvision
+ timm
+ rasterio
+ einops
+ huggingface_hub
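A note on the mask convention used throughout app.py: the MAE returns a binary patch mask where 0 means the patch was kept (visible) and 1 means it was removed, so run_model swaps in predictions only where the mask is 1 and then inverts the mask for display. A toy illustration of that mixing step, with made-up shapes and values:

import torch

# Pretend per-pixel outputs after unpatchify, in the same (B, C, T, H, W) layout as app.py
input_img = torch.full((1, 1, 1, 2, 2), 5.0)  # "visible" data
pred_img = torch.full((1, 1, 1, 2, 2), 9.0)   # model reconstruction
mask_img = torch.tensor([[[[[0., 1.], [1., 0.]]]]])  # 0 = keep, 1 = remove

# Mix visible and predicted pixels exactly as run_model does
rec_img = input_img.clone()
rec_img[mask_img == 1] = pred_img[mask_img == 1]
print(rec_img.squeeze())   # tensor([[5., 9.], [9., 5.]])

# Invert the mask for display so masked patches appear darker, as in app.py
mask_vis = (~mask_img.to(torch.bool)).to(torch.float)
print(mask_vis.squeeze())  # tensor([[1., 0.], [0., 1.]])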