jonathan-glider commited on
Commit
ced6b43
1 Parent(s): c225b3b

Upload 14 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ first.tif filter=lfs diff=lfs merge=lfs -text
+ second.tif filter=lfs diff=lfs merge=lfs -text
+ streamlit-testing.webm filter=lfs diff=lfs merge=lfs -text
+ temp/1ab7b057-9543-4240-92a1-f85bba853af6.jpg filter=lfs diff=lfs merge=lfs -text
+ temp/771eca76-489a-41c7-bcfb-28b841e78dd7.jpg filter=lfs diff=lfs merge=lfs -text
+ third.tif filter=lfs diff=lfs merge=lfs -text
Korea_data.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9384fd7b5ee9e5aaf80c7251469e0fd68925b5cb8e1a5fabea8fd2cd8bb7c9bd
+ size 483473233
Prithvi.py ADDED
@@ -0,0 +1,319 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # --------------------------------------------------------
11
+
12
+ from functools import partial
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ from timm.models.vision_transformer import Block
18
+ from timm.models.layers import to_2tuple
19
+
20
+ import numpy as np
21
+
22
+ from einops import rearrange
23
+
24
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
25
+ """
26
+ embed_dim: output dimension for each position
27
+ pos: a list of positions to be encoded: size (M,)
28
+ out: (M, D)
29
+ """
30
+ assert embed_dim % 2 == 0
31
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
32
+ omega /= embed_dim / 2.
33
+ omega = 1. / 10000**omega # (D/2,)
34
+
35
+ pos = pos.reshape(-1) # (M,)
36
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
37
+
38
+ emb_sin = np.sin(out) # (M, D/2)
39
+ emb_cos = np.cos(out) # (M, D/2)
40
+
41
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
42
+ return emb
43
+
44
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45
+ assert embed_dim % 2 == 0
46
+
47
+ # use half of dimensions to encode grid_h
48
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
49
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
50
+
51
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
52
+ return emb
53
+
54
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
55
+ """
56
+ grid_size: 3d tuple of grid size: t, h, w
57
+ return:
58
+ pos_embed: L, D
59
+ """
60
+
61
+ assert embed_dim % 16 == 0
62
+
63
+ t_size, h_size, w_size = grid_size
64
+
65
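+ # split embed_dim across the three axes: 6/16 for width, 6/16 for height, 4/16 for time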
+ w_embed_dim = embed_dim // 16 * 6
66
+ h_embed_dim = embed_dim // 16 * 6
67
+ t_embed_dim = embed_dim // 16 * 4
68
+
69
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
70
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
71
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
72
+
73
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
74
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
75
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
76
+
77
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
78
+
79
+ if cls_token:
80
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
81
+ return pos_embed
82
+
83
+
84
+ class PatchEmbed(nn.Module):
85
+ """ Frames of 2D Images to Patch Embedding
86
+ The 3D version of timm.models.vision_transformer.PatchEmbed
87
+ """
88
+ def __init__(
89
+ self,
90
+ img_size=224,
91
+ patch_size=16,
92
+ num_frames=3,
93
+ tubelet_size=1,
94
+ in_chans=3,
95
+ embed_dim=768,
96
+ norm_layer=None,
97
+ flatten=True,
98
+ bias=True,
99
+ ):
100
+ super().__init__()
101
+ img_size = to_2tuple(img_size)
102
+ patch_size = to_2tuple(patch_size)
103
+ self.img_size = img_size
104
+ self.patch_size = patch_size
105
+ self.num_frames = num_frames
106
+ self.tubelet_size = tubelet_size
107
+ self.grid_size = (num_frames // tubelet_size, img_size[0] // patch_size[0], img_size[1] // patch_size[1])
108
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
109
+ self.flatten = flatten
110
+
111
+ self.proj = nn.Conv3d(in_chans, embed_dim,
112
+ kernel_size=(tubelet_size, patch_size[0], patch_size[1]),
113
+ stride=(tubelet_size, patch_size[0], patch_size[1]), bias=bias)
114
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
115
+
116
+ def forward(self, x):
117
+ B, C, T, H, W = x.shape
118
+ x = self.proj(x)
119
+ if self.flatten:
120
+ x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
121
+ x = self.norm(x)
122
+ return x
123
+
124
+
125
+ class MaskedAutoencoderViT(nn.Module):
126
+ """ Masked Autoencoder with VisionTransformer backbone
127
+ """
128
+ def __init__(self, img_size=224, patch_size=16,
129
+ num_frames=3, tubelet_size=1,
130
+ in_chans=3, embed_dim=1024, depth=24, num_heads=16,
131
+ decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
132
+ mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
133
+ super().__init__()
134
+
135
+ # --------------------------------------------------------------------------
136
+ # MAE encoder specifics
137
+ self.patch_embed = PatchEmbed(img_size, patch_size,num_frames, tubelet_size, in_chans, embed_dim)
138
+ num_patches = self.patch_embed.num_patches
139
+
140
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
141
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding
142
+
143
+ self.blocks = nn.ModuleList([
144
+ Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
145
+ for i in range(depth)])
146
+ self.norm = norm_layer(embed_dim)
147
+ # --------------------------------------------------------------------------
148
+
149
+ # --------------------------------------------------------------------------
150
+ # MAE decoder specifics
151
+ self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
152
+
153
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
154
+
155
+ self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False) # fixed sin-cos embedding
156
+
157
+ self.decoder_blocks = nn.ModuleList([
158
+ Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
159
+ for i in range(decoder_depth)])
160
+
161
+ self.decoder_norm = norm_layer(decoder_embed_dim)
162
+ self.decoder_pred = nn.Linear(decoder_embed_dim, tubelet_size * patch_size * patch_size * in_chans, bias=True) # decoder to patch
163
+ # --------------------------------------------------------------------------
164
+
165
+ self.norm_pix_loss = norm_pix_loss
166
+
167
+ self.initialize_weights()
168
+
169
+ def initialize_weights(self):
170
+ # initialization
171
+ # initialize (and freeze) pos_embed by sin-cos embedding
172
+ pos_embed = get_3d_sincos_pos_embed(self.pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
173
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
174
+
175
+ decoder_pos_embed = get_3d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], self.patch_embed.grid_size, cls_token=True)
176
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
177
+
178
+ # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
179
+ w = self.patch_embed.proj.weight.data
180
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
181
+
182
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
183
+ torch.nn.init.normal_(self.cls_token, std=.02)
184
+ torch.nn.init.normal_(self.mask_token, std=.02)
185
+
186
+ # initialize nn.Linear and nn.LayerNorm
187
+ self.apply(self._init_weights)
188
+
189
+ def _init_weights(self, m):
190
+ if isinstance(m, nn.Linear):
191
+ # we use xavier_uniform following official JAX ViT:
192
+ torch.nn.init.xavier_uniform_(m.weight)
193
+ if isinstance(m, nn.Linear) and m.bias is not None:
194
+ nn.init.constant_(m.bias, 0)
195
+ elif isinstance(m, nn.LayerNorm):
196
+ nn.init.constant_(m.bias, 0)
197
+ nn.init.constant_(m.weight, 1.0)
198
+
199
+ def patchify(self, imgs):
200
+ """
201
+ imgs: B, C, T, H, W
202
+ x: B, L, D
203
+ """
204
+ p = self.patch_embed.patch_size[0]
205
+ tub = self.patch_embed.tubelet_size
206
+ x = rearrange(imgs, 'b c (t tub) (h p) (w q) -> b (t h w) (tub p q c)', tub=tub, p=p, q=p)
207
+
208
+ return x
209
+
210
+ def unpatchify(self, x):
211
+ """
212
+ x: B, L, D
213
+ imgs: B, C, T, H, W
214
+ """
215
+ p = self.patch_embed.patch_size[0]
216
+ num_p = self.patch_embed.img_size[0] // p
217
+ tub = self.patch_embed.tubelet_size
218
+ imgs = rearrange(x, 'b (t h w) (tub p q c) -> b c (t tub) (h p) (w q)', h=num_p, w=num_p, tub=tub, p=p, q=p)
219
+ return imgs
220
+
221
+ def random_masking(self, x, mask_ratio):
222
+ """
223
+ Perform per-sample random masking by per-sample shuffling.
224
+ Per-sample shuffling is done by argsort random noise.
225
+ x: [N, L, D], sequence
226
+ """
227
+ N, L, D = x.shape # batch, length, dim
228
+ len_keep = int(L * (1 - mask_ratio))
229
+
230
+ noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
231
+
232
+ # sort noise for each sample
233
+ ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
234
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
235
+
236
+ # keep the first subset
237
+ ids_keep = ids_shuffle[:, :len_keep]
238
+ x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
239
+
240
+ # generate the binary mask: 0 is keep, 1 is remove
241
+ mask = torch.ones([N, L], device=x.device)
242
+ mask[:, :len_keep] = 0
243
+ # unshuffle to get the binary mask
244
+ mask = torch.gather(mask, dim=1, index=ids_restore)
245
+
246
+ return x_masked, mask, ids_restore
247
+
248
+ def forward_encoder(self, x, mask_ratio):
249
+ # embed patches
250
+ x = self.patch_embed(x)
251
+
252
+ # add pos embed w/o cls token
253
+ x = x + self.pos_embed[:, 1:, :]
254
+
255
+ # masking: length -> length * mask_ratio
256
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
257
+
258
+ # append cls token
259
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
260
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
261
+ x = torch.cat((cls_tokens, x), dim=1)
262
+
263
+ # apply Transformer blocks
264
+ for blk in self.blocks:
265
+ x = blk(x)
266
+ x = self.norm(x)
267
+
268
+ return x, mask, ids_restore
269
+
270
+ def forward_decoder(self, x, ids_restore):
271
+ # embed tokens
272
+ x = self.decoder_embed(x)
273
+
274
+ # append mask tokens to sequence
275
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
276
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
277
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
278
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
279
+
280
+ # add pos embed
281
+ x = x + self.decoder_pos_embed
282
+
283
+ # apply Transformer blocks
284
+ for blk in self.decoder_blocks:
285
+ x = blk(x)
286
+ x = self.decoder_norm(x)
287
+
288
+ # predictor projection
289
+ x = self.decoder_pred(x)
290
+
291
+ # remove cls token
292
+ x = x[:, 1:, :]
293
+
294
+ return x
295
+
296
+ def forward_loss(self, imgs, pred, mask):
297
+ """
298
+ imgs: B, C, T, H, W
299
+ target: B, L, D
300
+ pred: B, L, D
301
+ mask: B, L. 0 is keep, 1 is remove,
302
+ """
303
+ target = self.patchify(imgs)
304
+ if self.norm_pix_loss:
305
+ mean = target.mean(dim=-1, keepdim=True)
306
+ var = target.var(dim=-1, keepdim=True)
307
+ target = (target - mean) / (var + 1.e-6)**.5
308
+
309
+ loss = (pred - target) ** 2
310
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
311
+
312
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
313
+ return loss
314
+
315
+ def forward(self, imgs, mask_ratio=0.75):
316
+ latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
317
+ pred = self.forward_decoder(latent, ids_restore)
318
+ loss = self.forward_loss(imgs, pred, mask)
319
+ return loss, pred, mask
Prithvi_100M.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:69f8ac286f649d1bbed520f5c8560a60eba91d688f74e1a0f9aa8203b6fd62ab
+ size 453672901
Prithvi_100M_config.yaml ADDED
@@ -0,0 +1,19 @@
+ num_frames: 3
+ img_size: 224
+ bands: [B02, B03, B04, B05, B06, B07]
+ random_cropping: true
+ data_loader_num_workers: 1
+ depth: 12
+ decoder_depth: 8
+ patch_size: 16
+ embed_dim: 768
+ decoder_embed_dim: 512
+ num_heads: 12
+ decoder_num_heads: 16
+ mask_ratio: 0.75
+ tubelet_size: 1
+ data_mean: [775.2290211032589, 1080.992780391705, 1228.5855250417867, 2497.2022620507532,
+   2204.2139147975554, 1610.8324823273745]
+ data_std: [1281.526139861424, 1270.0297974547493, 1399.4802505642526, 1368.3446143747644,
+   1291.6764008585435, 1154.505683480695]
+ batch_size: 16
Prithvi_run_inference.py ADDED
@@ -0,0 +1,399 @@
1
+ import argparse
2
+ import functools
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import rasterio
8
+ import torch
9
+ import yaml
10
+ from einops import rearrange
11
+
12
+ from Prithvi import MaskedAutoencoderViT
13
+
14
+ NO_DATA = -9999
15
+ NO_DATA_FLOAT = 0.0001
16
+ PERCENTILES = (0.1, 99.9)
17
+
18
+
19
+ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
20
+ """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
21
+ original range using *data_mean* and *data_std* and then lowest and highest percentiles are
22
+ removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
23
+
24
+ Args:
25
+ orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
26
+ new_img: torch.Tensor representing image with shape = (bands, H, W).
27
+ channels: list of indices representing RGB channels.
28
+ data_mean: list of mean values for each band.
29
+ data_std: list of std values for each band.
30
+
31
+ Returns:
32
+ torch.Tensor with shape (num_channels, height, width) for original image
33
+ torch.Tensor with shape (num_channels, height, width) for the other image
34
+ """
35
+
36
+ stack_c = [], []
37
+
38
+ for c in channels:
39
+ orig_ch = orig_img[c, ...]
40
+ valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
41
+ valid_mask[orig_ch == NO_DATA_FLOAT] = False
42
+
43
+ # Back to original data range
44
+ orig_ch = (orig_ch * data_std[c]) + data_mean[c]
45
+ new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
46
+
47
+ # Rescale (enhancing contrast)
48
+ min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
49
+
50
+ orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
51
+ new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
52
+
53
+ # No data as zeros
54
+ orig_ch[~valid_mask] = 0
55
+ new_ch[~valid_mask] = 0
56
+
57
+ stack_c[0].append(orig_ch)
58
+ stack_c[1].append(new_ch)
59
+
60
+ # Channels first
61
+ stack_orig = torch.stack(stack_c[0], dim=0)
62
+ stack_rec = torch.stack(stack_c[1], dim=0)
63
+
64
+ return stack_orig, stack_rec
65
+
66
+
67
+ def read_geotiff(file_path: str):
68
+ """ Read all bands from *file_path* and return image + meta info.
69
+
70
+ Args:
71
+ file_path: path to image file.
72
+
73
+ Returns:
74
+ np.ndarray with shape (bands, height, width)
75
+ meta info dict
76
+ """
77
+
78
+ with rasterio.open(file_path) as src:
79
+ img = src.read()
80
+ meta = src.meta
81
+
82
+ return img, meta
83
+
84
+
85
+ def save_geotiff(image, output_path: str, meta: dict):
86
+ """ Save multi-band image in Geotiff file.
87
+
88
+ Args:
89
+ image: np.ndarray with shape (bands, height, width)
90
+ output_path: path where to save the image
91
+ meta: dict with meta info.
92
+ """
93
+
94
+ with rasterio.open(output_path, "w", **meta) as dest:
95
+ for i in range(image.shape[0]):
96
+ dest.write(image[i, :, :], i + 1)
97
+
98
+ return
99
+
100
+
101
+ def _convert_np_uint8(float_image: torch.Tensor):
102
+
103
+ image = float_image.numpy() * 255.0
104
+ image = image.astype(dtype=np.uint8)
105
+
106
+ return image
107
+
108
+
109
+ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
110
+ """ Build an input example by loading images in *file_paths*.
111
+
112
+ Args:
113
+ file_paths: list of file paths.
114
+ mean: list containing mean values for each band in the images in *file_paths*.
115
+ std: list containing std values for each band in the images in *file_paths*.
116
+
117
+ Returns:
118
+ np.array containing created example
119
+ list of meta info for each image in *file_paths*
120
+ """
121
+
122
+ imgs = []
123
+ metas = []
124
+
125
+ for file in file_paths:
126
+ img, meta = read_geotiff(file)
127
+
128
+ # Rescaling (don't normalize on nodata)
129
+ img = np.moveaxis(img, 0, -1) # channels last for rescaling
130
+ img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
131
+
132
+ imgs.append(img)
133
+ metas.append(meta)
134
+
135
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
136
+ imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
137
+ imgs = np.expand_dims(imgs, axis=0) # add batch dim
138
+
139
+ return imgs, metas
140
+
141
+
142
+ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
143
+ """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
144
+
145
+ Args:
146
+ model: MAE model to run.
147
+ input_data: torch.Tensor with shape (B, C, T, H, W).
148
+ mask_ratio: mask ratio to use.
149
+ device: device where model should run.
150
+
151
+ Returns:
152
+ 2 torch.Tensor with shape (B, C, T, H, W): the reconstructed image and the mask image.
153
+ """
154
+
155
+ with torch.no_grad():
156
+ x = input_data.to(device)
157
+
158
+ _, pred, mask = model(x, mask_ratio)
159
+
160
+ # Create mask and prediction images (un-patchify)
161
+ mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
162
+ pred_img = model.unpatchify(pred).detach().cpu()
163
+
164
+ # Mix visible and predicted patches
165
+ rec_img = input_data.clone()
166
+ rec_img[mask_img == 1] = pred_img[mask_img == 1] # binary mask: 0 is keep, 1 is remove
167
+
168
+ # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
169
+ mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
170
+
171
+ return rec_img, mask_img
172
+
173
+
174
+ def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
175
+ """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
176
+
177
+ Args:
178
+ input_img: input torch.Tensor with shape (C, T, H, W).
179
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
180
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
181
+ channels: list of indices representing RGB channels.
182
+ mean: list of mean values for each band.
183
+ std: list of std values for each band.
184
+ output_dir: directory where to save outputs.
185
+ meta_data: list of dicts with geotiff meta info.
186
+ """
187
+
188
+ for t in range(input_img.shape[1]):
189
+ rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
190
+ new_img=rec_img[:, t, :, :],
191
+ channels=channels, data_mean=mean,
192
+ data_std=std)
193
+
194
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
195
+
196
+ # Saving images
197
+
198
+ save_geotiff(image=_convert_np_uint8(rgb_orig),
199
+ output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
200
+ meta=meta_data[t])
201
+
202
+ save_geotiff(image=_convert_np_uint8(rgb_pred),
203
+ output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
204
+ meta=meta_data[t])
205
+
206
+ save_geotiff(image=_convert_np_uint8(rgb_mask),
207
+ output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
208
+ meta=meta_data[t])
209
+
210
+
211
+ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
212
+ """ Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
213
+
214
+ Args:
215
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
216
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
217
+ mean: list of mean values for each band.
218
+ std: list of std values for each band.
219
+ output_dir: directory where to save outputs.
220
+ meta_data: list of dicts with geotiff meta info.
221
+ """
222
+
223
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
224
+ std = torch.tensor(np.asarray(std)[:, None, None])
225
+
226
+ for t in range(rec_img.shape[1]):
227
+
228
+ # Back to original data range
229
+ rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
230
+
231
+ mask_img_t = mask_img[:, t, :, :].to(torch.int16)
232
+
233
+ # Saving images
234
+
235
+ save_geotiff(image=rec_img_t,
236
+ output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
237
+ meta=meta_data[t])
238
+
239
+ save_geotiff(image=mask_img_t,
240
+ output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
241
+ meta=meta_data[t])
242
+
243
+
244
+ def main(data_files: List[str], yaml_file_path: str, checkpoint: str, output_dir: str,
245
+ mask_ratio: float, rgb_outputs: bool):
246
+
247
+ os.makedirs(output_dir, exist_ok=True)
248
+
249
+ # Get parameters --------
250
+
251
+ with open(yaml_file_path, 'r') as f:
252
+ params = yaml.safe_load(f)
253
+
254
+ # data related
255
+ num_frames = len(data_files)
256
+ img_size = params['img_size']
257
+ bands = params['bands']
258
+ mean = params['data_mean']
259
+ std = params['data_std']
260
+
261
+ # model related
262
+ depth = params['depth']
263
+ patch_size = params['patch_size']
264
+ embed_dim = params['embed_dim']
265
+ num_heads = params['num_heads']
266
+ tubelet_size = params['tubelet_size']
267
+ decoder_embed_dim = params['decoder_embed_dim']
268
+ decoder_num_heads = params['decoder_num_heads']
269
+ decoder_depth = params['decoder_depth']
270
+
271
+ batch_size = params['batch_size']
272
+
273
+ mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
274
+
275
+ print(f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n")
276
+ if len(data_files) != 3:
277
+ print("The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary")
278
+
279
+ if torch.cuda.is_available():
280
+ device = torch.device('cuda')
281
+ else:
282
+ device = torch.device('cpu')
283
+
284
+ print(f"Using {device} device.\n")
285
+
286
+ # Loading data ---------------------------------------------------------------------------------
287
+
288
+ input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)
289
+
290
+ # Create model and load checkpoint -------------------------------------------------------------
291
+
292
+ model = MaskedAutoencoderViT(
293
+ img_size=img_size,
294
+ patch_size=patch_size,
295
+ num_frames=num_frames,
296
+ tubelet_size=tubelet_size,
297
+ in_chans=len(bands),
298
+ embed_dim=embed_dim,
299
+ depth=depth,
300
+ num_heads=num_heads,
301
+ decoder_embed_dim=decoder_embed_dim,
302
+ decoder_depth=decoder_depth,
303
+ decoder_num_heads=decoder_num_heads,
304
+ mlp_ratio=4.,
305
+ norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
306
+ norm_pix_loss=False)
307
+
308
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
309
+ print(f"\n--> Model has {total_params:,} parameters.\n")
310
+
311
+ model.to(device)
312
+
313
+ state_dict = torch.load(checkpoint, map_location=device)
314
+ # discard fixed pos_embedding weight
315
+ del state_dict['pos_embed']
316
+ del state_dict['decoder_pos_embed']
317
+ model.load_state_dict(state_dict, strict=False)
318
+ print(f"Loaded checkpoint from {checkpoint}")
319
+
320
+ # Running model --------------------------------------------------------------------------------
321
+
322
+ model.eval()
323
+ channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
324
+
325
+ # Reflect pad if not divisible by img_size
326
+ original_h, original_w = input_data.shape[-2:]
327
+ pad_h = img_size - (original_h % img_size)
328
+ pad_w = img_size - (original_w % img_size)
329
+ input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
330
+
331
+ # Build sliding window
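+ # (two unfolds with step == window size tile the padded scene into non-overlapping img_size x img_size windows)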
332
+ batch = torch.tensor(input_data, device='cpu')
333
+ windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
334
+ h1, w1 = windows.shape[3:5]
335
+ windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)
336
+
337
+ # Split into batches if number of windows > batch_size
338
+ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
339
+ windows = torch.tensor_split(windows, num_batches, dim=0)
340
+
341
+ # Run model
342
+ rec_imgs = []
343
+ mask_imgs = []
344
+ for x in windows:
345
+ rec_img, mask_img = run_model(model, x, mask_ratio, device)
346
+ rec_imgs.append(rec_img)
347
+ mask_imgs.append(mask_img)
348
+
349
+ rec_imgs = torch.concat(rec_imgs, dim=0)
350
+ mask_imgs = torch.concat(mask_imgs, dim=0)
351
+
352
+ # Build images from patches
353
+ rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
354
+ h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
355
+ mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
356
+ h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
357
+
358
+ # Cut padded images back to original size
359
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
360
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
361
+ batch_full = batch[..., :original_h, :original_w]
362
+
363
+ # Build output images
364
+ if rgb_outputs:
365
+ for d in meta_data:
366
+ d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
367
+
368
+ save_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
369
+ channels, mean, std, output_dir, meta_data)
370
+ else:
371
+ for d in meta_data:
372
+ d.update(compress='lzw', nodata=0)
373
+
374
+ save_imgs(rec_imgs_full[0, ...], mask_imgs_full[0, ...], mean, std, output_dir, meta_data)
375
+
376
+ print("Done!")
377
+
378
+
379
+ if __name__ == "__main__":
380
+ parser = argparse.ArgumentParser('MAE run inference', add_help=False)
381
+
382
+ parser.add_argument('--data_files', required=True, type=str, nargs='+',
383
+ help='Path to the data files. Assumes multi-band files.')
384
+ parser.add_argument('--yaml_file_path', type=str, required=True,
385
+ help='Path to yaml file containing model training parameters.')
386
+ parser.add_argument('--checkpoint', required=True, type=str,
387
+ help='Path to a checkpoint file to load from.')
388
+ parser.add_argument('--output_dir', required=True, type=str,
389
+ help='Path to the directory where to save outputs.')
390
+ parser.add_argument('--mask_ratio', default=None, type=float,
391
+ help='Masking ratio (percentage of removed patches). '
392
+ 'If None (default) use same value used for pretraining.')
393
+ parser.add_argument('--rgb_outputs', action='store_true',
394
+ help='If present, output files will only contain RGB channels. '
395
+ 'Otherwise, all bands will be saved.')
396
+ args = parser.parse_args()
397
+
398
+ main(**vars(args))
399
+
README.md ADDED
@@ -0,0 +1,41 @@
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - Pytorch
5
+ - Geospatial
6
+ - Temporal ViT
7
+ - Vit
8
+ ---
9
+
10
+
11
+ ### Code
12
+ The model follows the [original repo](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M).
+ I made a simple modification to the original repo's source code to add visualization.
14
+
15
+ ### Data
16
+ Area: Jeollanam-do, South Korea
17
+ Sourced the image from this [link](https://search.earthdata.nasa.gov/search/granules?p=C2021957657-LPCLOUD&pg[0][v]=f&pg[0][qt]=2009-01-01T00%3A00%3A00.000Z%2C&pg[0][dnf]=DAY&pg[0][gsk]=-start_date&q=HLSL30&sb[0]=126.57129%2C34.87923%2C126.97998%2C35.09012&tl=1696429462!3!!&lat=33.70166015625&long=125.0771484375&zoom=7)
18
+ Google map location [link](https://www.google.co.kr/maps/place/34%C2%B052'45.2%22N+126%C2%B034'16.6%22E/data=!4m4!3m3!8m2!3d34.87923!4d126.57129?hl=ko&entry=ttu)
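+
+ The three GeoTIFFs can be pushed through the bundled inference script. A minimal sketch, assuming the files and checkpoint sit in the repository root (paths are illustrative):
+
+ ```python
+ from Prithvi_run_inference import main
+
+ main(
+     data_files=["first.tif", "second.tif", "third.tif"],
+     yaml_file_path="Prithvi_100M_config.yaml",
+     checkpoint="Prithvi_100M.pt",
+     output_dir="output",
+     mask_ratio=0.75,
+     rgb_outputs=True,
+ )
+ ```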
19
+
20
+ ### Use case
21
+ Here's a sample video:
22
+
23
+ ![Sample Video](streamlit-testing.webm)
24
+
25
+
26
+
27
+ ### Citation
28
+
29
+ Please cite the original repository.
30
+ If this model helped your research, please cite `Prithvi-100M` in your publications. Here is an example BibTeX entry:
31
+
32
+ ```
33
+ @misc{Prithvi-100M,
34
+ author = {Jakubik, Johannes and Chu, Linsong and Fraccaro, Paolo and Gomes, Carlos and Nyirjesy, Gabby and Bangalore, Ranjini and Lambhate, Devyani and Das, Kamal and Oliveira Borges, Dario and Kimura, Daiki and Simumba, Naomi and Szwarcman, Daniela and Muszynski, Michal and Weldemariam, Kommy and Zadrozny, Bianca and Ganti, Raghu and Costa, Carlos and Edwards, Blair & Watson, Campbell and Mukkavilli, Karthik and Schmude, Johannes & Hamann, Hendrik and Robert, Parkin and Roy, Sujit and Phillips, Christopher and Ankur, Kumar and Ramasubramanian, Muthukumaran and Gurung, Iksha and Leong, Wei Ji and Avery, Ryan and Ramachandran, Rahul and Maskey, Manil and Olofossen, Pontus and Fancher, Elizabeth and Lee, Tsengdar and Murphy, Kevin and Duffy, Dan and Little, Mike and Alemohammad, Hamed and Cecil, Michael and Li, Steve and Khallaghi, Sam and Godwin, Denys and Ahmadi, Maryam and Kordi, Fatemeh and Saux, Bertrand and Pastick, Neal and Doucette, Peter and Fleckenstein, Rylie and Luanga, Dalton and Corvin, Alex and Granger, Erwan},
35
+ doi = {10.57967/hf/0952},
36
+ month = aug,
37
+ title = {{Prithvi-100M}},
38
+ repository-code = {https://github.com/NASA-IMPACT/hls-foundation-os},
39
+ year = {2023}
40
+ }
41
+ ```
app.py ADDED
@@ -0,0 +1,483 @@
1
+ from huggingface_hub import hf_hub_download
2
+ import os
3
+ import functools
4
+ from typing import List
5
+ import numpy as np
6
+ import torch
7
+ import yaml
8
+ from einops import rearrange
9
+ from Prithvi import MaskedAutoencoderViT
10
+ from functools import partial
11
+
12
+ import rasterio
13
+ from rasterio.merge import merge
14
+ from rasterio import Affine
15
+ from rasterio.warp import calculate_default_transform, reproject, Resampling
16
+
17
+ import streamlit as st
18
+ from streamlit_image_comparison import image_comparison
19
+
20
+ NO_DATA = -9999
21
+ NO_DATA_FLOAT = 0.0001
22
+ PERCENTILES = (0.1, 99.9)
23
+
24
+ TOKEN = "JONATHAN_TOKEN"
25
+ yaml_file_path=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename="Prithvi_100M_config.yaml", token=TOKEN)
26
+ checkpoint=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi_100M.pt', token=TOKEN)
27
+ model_def=hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-100M", filename='Prithvi.py', token=TOKEN)
28
+
29
+
30
+
31
+ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
32
+ """ Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
33
+ original range using *data_mean* and *data_std* and then lowest and highest percentiles are
34
+ removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
35
+ Args:
36
+ orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
37
+ new_img: torch.Tensor representing image with shape = (bands, H, W).
38
+ channels: list of indices representing RGB channels.
39
+ data_mean: list of mean values for each band.
40
+ data_std: list of std values for each band.
41
+ Returns:
42
+ torch.Tensor with shape (num_channels, height, width) for original image
43
+ torch.Tensor with shape (num_channels, height, width) for the other image
44
+ """
45
+
46
+ stack_c = [], []
47
+
48
+ for c in channels:
49
+ orig_ch = orig_img[c, ...]
50
+ valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
51
+ valid_mask[orig_ch == NO_DATA_FLOAT] = False
52
+
53
+ # Back to original data range
54
+ orig_ch = (orig_ch * data_std[c]) + data_mean[c]
55
+ new_ch = (new_img[c, ...] * data_std[c]) + data_mean[c]
56
+
57
+ # Rescale (enhancing contrast)
58
+ min_value, max_value = np.percentile(orig_ch[valid_mask], PERCENTILES)
59
+
60
+ orig_ch = torch.clamp((orig_ch - min_value) / (max_value - min_value), 0, 1)
61
+ new_ch = torch.clamp((new_ch - min_value) / (max_value - min_value), 0, 1)
62
+
63
+ # No data as zeros
64
+ orig_ch[~valid_mask] = 0
65
+ new_ch[~valid_mask] = 0
66
+
67
+ stack_c[0].append(orig_ch)
68
+ stack_c[1].append(new_ch)
69
+
70
+ # Channels first
71
+ stack_orig = torch.stack(stack_c[0], dim=0)
72
+ stack_rec = torch.stack(stack_c[1], dim=0)
73
+
74
+ return stack_orig, stack_rec
75
+
76
+
77
+ def read_geotiff(file_path: str):
78
+ """ Read all bands from *file_path* and returns image + meta info.
79
+ Args:
80
+ file_path: path to image file.
81
+ Returns:
82
+ np.ndarray with shape (bands, height, width)
83
+ meta info dict
84
+ """
85
+
86
+ with rasterio.open(file_path) as src:
87
+ img = src.read()
88
+ meta = src.meta
89
+
90
+ return img, meta
91
+
92
+
93
+ def save_geotiff(image, output_path: str, meta: dict):
94
+ """ Save multi-band image in Geotiff file.
95
+ Args:
96
+ image: np.ndarray with shape (bands, height, width)
97
+ output_path: path where to save the image
98
+ meta: dict with meta info.
99
+ """
100
+
101
+ with rasterio.open(output_path, "w", **meta) as dest:
102
+ for i in range(image.shape[0]):
103
+ dest.write(image[i, :, :], i + 1)
104
+
105
+ return
106
+
107
+
108
+ def _convert_np_uint8(float_image: torch.Tensor):
109
+
110
+ image = float_image.numpy() * 255.0
111
+ image = image.astype(dtype=np.uint8)
112
+ image = image.transpose((1, 2, 0))
113
+
114
+ return image
115
+
116
+
117
+ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
118
+ """ Build an input example by loading images in *file_paths*.
119
+ Args:
120
+ file_paths: list of file paths.
121
+ mean: list containing mean values for each band in the images in *file_paths*.
122
+ std: list containing std values for each band in the images in *file_paths*.
123
+ Returns:
124
+ np.array containing created example
125
+ list of meta info for each image in *file_paths*
126
+ """
127
+
128
+ imgs = []
129
+ metas = []
130
+
131
+ for file in file_paths:
132
+ img, meta = read_geotiff(file)
133
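+ # heuristic: if the first six bands look like 0-1 reflectance (mean <= 2), scale by 10000 to match the data_mean/data_std statistics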
+ img = img[:6]*10000 if img[:6].mean() <= 2 else img[:6]
134
+
135
+ # Rescaling (don't normalize on nodata)
136
+ img = np.moveaxis(img, 0, -1) # channels last for rescaling
137
+ img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
138
+
139
+ imgs.append(img)
140
+ metas.append(meta)
141
+
142
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
143
+ imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
144
+ imgs = np.expand_dims(imgs, axis=0) # add batch dim
145
+
146
+ return imgs, metas
147
+
148
+
149
+ def run_model(model: torch.nn.Module, input_data: torch.Tensor, mask_ratio: float, device: torch.device):
150
+ """ Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
151
+ Args:
152
+ model: MAE model to run.
153
+ input_data: torch.Tensor with shape (B, C, T, H, W).
154
+ mask_ratio: mask ratio to use.
155
+ device: device where model should run.
156
+ Returns:
157
+ 2 torch.Tensor with shape (B, C, T, H, W): the reconstructed image and the mask image.
158
+ """
159
+
160
+ with torch.no_grad():
161
+ x = input_data.to(device)
162
+
163
+ _, pred, mask = model(x, mask_ratio)
164
+
165
+ # Create mask and prediction images (un-patchify)
166
+ mask_img = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
167
+ pred_img = model.unpatchify(pred).detach().cpu()
168
+
169
+ # Mix visible and predicted patches
170
+ rec_img = input_data.clone()
171
+ rec_img[mask_img == 1] = pred_img[mask_img == 1] # binary mask: 0 is keep, 1 is remove
172
+
173
+ # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
174
+ mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
175
+
176
+ return rec_img, mask_img
177
+
178
+
179
+ def save_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data):
180
+ """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
181
+ Args:
182
+ input_img: input torch.Tensor with shape (C, T, H, W).
183
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
184
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
185
+ channels: list of indices representing RGB channels.
186
+ mean: list of mean values for each band.
187
+ std: list of std values for each band.
188
+ output_dir: directory where to save outputs.
189
+ meta_data: list of dicts with geotiff meta info.
190
+ """
191
+
192
+ for t in range(input_img.shape[1]):
193
+ rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
194
+ new_img=rec_img[:, t, :, :],
195
+ channels=channels, data_mean=mean,
196
+ data_std=std)
197
+
198
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
199
+
200
+ # Saving images
201
+
202
+ save_geotiff(image=_convert_np_uint8(rgb_orig),
203
+ output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
204
+ meta=meta_data[t])
205
+
206
+ save_geotiff(image=_convert_np_uint8(rgb_pred),
207
+ output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
208
+ meta=meta_data[t])
209
+
210
+ save_geotiff(image=_convert_np_uint8(rgb_mask),
211
+ output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
212
+ meta=meta_data[t])
213
+
214
+
215
+ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
216
+ """ Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
217
+ Args:
218
+ input_img: input torch.Tensor with shape (C, T, H, W).
219
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
220
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
221
+ channels: list of indices representing RGB channels.
222
+ mean: list of mean values for each band.
223
+ std: list of std values for each band.
224
+ Returns:
+ list of np.uint8 images: originals for each timestamp, then masked versions, then reconstructions.
226
+ """
227
+ rgb_orig_list = []
228
+ rgb_mask_list = []
229
+ rgb_pred_list = []
230
+
231
+ for t in range(input_img.shape[1]):
232
+ rgb_orig, rgb_pred = process_channel_group(orig_img=input_img[:, t, :, :],
233
+ new_img=rec_img[:, t, :, :],
234
+ channels=channels, data_mean=mean,
235
+ data_std=std)
236
+
237
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
238
+
239
+ # extract images
240
+ rgb_orig_list.append(_convert_np_uint8(rgb_orig))
241
+ rgb_mask_list.append(_convert_np_uint8(rgb_mask))
242
+ rgb_pred_list.append(_convert_np_uint8(rgb_pred))
243
+
244
+ outputs = rgb_orig_list + rgb_mask_list + rgb_pred_list
245
+
246
+ return outputs
247
+
248
+
249
+ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str, checkpoint: str):
250
+
251
+
252
+ try:
253
+ data_files = [x.name for x in data_files]
254
+ print('Path extracted from example')
255
+ except:
256
+ print('Files submitted through UI')
257
+
258
+ # Get parameters --------
259
+ print('This is the printout', data_files)
260
+
261
+ with open(yaml_file_path, 'r') as f:
262
+ params = yaml.safe_load(f)
263
+
264
+ # data related
265
+ num_frames = params['num_frames']
266
+ img_size = params['img_size']
267
+ bands = params['bands']
268
+ mean = params['data_mean']
269
+ std = params['data_std']
270
+
271
+ # model related
272
+ depth = params['depth']
273
+ patch_size = params['patch_size']
274
+ embed_dim = params['embed_dim']
275
+ num_heads = params['num_heads']
276
+ tubelet_size = params['tubelet_size']
277
+ decoder_embed_dim = params['decoder_embed_dim']
278
+ decoder_num_heads = params['decoder_num_heads']
279
+ decoder_depth = params['decoder_depth']
280
+
281
+ batch_size = params['batch_size']
282
+
283
+ mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
284
+
285
+ # We must have *num_frames* files to build one example!
286
+ assert len(data_files) == num_frames, "File list must be equal to expected number of frames."
287
+
288
+ if torch.cuda.is_available():
289
+ device = torch.device('cuda')
290
+ else:
291
+ device = torch.device('cpu')
292
+
293
+ print(f"Using {device} device.\n")
294
+
295
+ # Loading data ---------------------------------------------------------------------------------
296
+
297
+ input_data, meta_data = load_example(file_paths=data_files, mean=mean, std=std)
298
+
299
+ # Create model and load checkpoint -------------------------------------------------------------
300
+
301
+ model = MaskedAutoencoderViT(
302
+ img_size=img_size,
303
+ patch_size=patch_size,
304
+ num_frames=num_frames,
305
+ tubelet_size=tubelet_size,
306
+ in_chans=len(bands),
307
+ embed_dim=embed_dim,
308
+ depth=depth,
309
+ num_heads=num_heads,
310
+ decoder_embed_dim=decoder_embed_dim,
311
+ decoder_depth=decoder_depth,
312
+ decoder_num_heads=decoder_num_heads,
313
+ mlp_ratio=4.,
314
+ norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
315
+ norm_pix_loss=False)
316
+
317
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
318
+ print(f"\n--> Model has {total_params:,} parameters.\n")
319
+
320
+ model.to(device)
321
+
322
+ state_dict = torch.load(checkpoint, map_location=device)
323
+ model.load_state_dict(state_dict)
324
+ print(f"Loaded checkpoint from {checkpoint}")
325
+
326
+ # Running model --------------------------------------------------------------------------------
327
+
328
+ model.eval()
329
+ channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
330
+
331
+ # Reflect pad if not divisible by img_size
332
+ original_h, original_w = input_data.shape[-2:]
333
+ pad_h = img_size - (original_h % img_size)
334
+ pad_w = img_size - (original_w % img_size)
335
+ input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
336
+
337
+ # Build sliding window
338
+ batch = torch.tensor(input_data, device='cpu')
339
+ windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
340
+ h1, w1 = windows.shape[3:5]
341
+ windows = rearrange(windows, 'b c t h1 w1 h w -> (b h1 w1) c t h w', h=img_size, w=img_size)
342
+
343
+ # Split into batches if number of windows > batch_size
344
+ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
345
+ windows = torch.tensor_split(windows, num_batches, dim=0)
346
+
347
+ # Run model
348
+ rec_imgs = []
349
+ mask_imgs = []
350
+ for x in windows:
351
+ rec_img, mask_img = run_model(model, x, mask_ratio, device)
352
+ rec_imgs.append(rec_img)
353
+ mask_imgs.append(mask_img)
354
+
355
+ rec_imgs = torch.concat(rec_imgs, dim=0)
356
+ mask_imgs = torch.concat(mask_imgs, dim=0)
357
+
358
+ # Build images from patches
359
+ rec_imgs = rearrange(rec_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
360
+ h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
361
+ mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
362
+ h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
363
+
364
+ # Cut padded images back to original size
365
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
366
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
367
+ batch_full = batch[..., :original_h, :original_w]
368
+
369
+ # Build RGB images
370
+ for d in meta_data:
371
+ d.update(count=3, dtype='uint8', compress='lzw', nodata=0)
372
+
373
+ # save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
374
+ # channels, mean, std, output_dir, meta_data)
375
+
376
+ outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
377
+ channels, mean, std)
378
+
379
+
380
+ print("Done!")
381
+
382
+ return outputs
383
+
384
+ # partial function prep
385
+ func = partial(predict_on_images, yaml_file_path=yaml_file_path,checkpoint=checkpoint,mask_ratio=0.75)
386
+
387
+
388
+
389
+
390
+ ## South Korea rural area HLS Landsat tile merger (bands B02 through B07)
391
+
392
+ def raw_NASA_tif_file_merger(PSD_NAME,tif_files):
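+ # merge a list of HLS GeoTIFF tile paths into one mosaic and write it to *PSD_NAME*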
393
+ src_files_to_mosaic = []
394
+ for tif_file in tif_files:
395
+ src = rasterio.open(tif_file)
396
+ src_files_to_mosaic.append(src)
397
+ mosaic, out_trans = merge(src_files_to_mosaic)
398
+ out_meta = src.meta.copy()
399
+ out_meta.update({"driver": "GTiff",
400
+ "height": mosaic.shape[1],
401
+ "width": mosaic.shape[2],
402
+ "transform": out_trans})
403
+
404
+
405
+ with rasterio.open(PSD_NAME, "w", **out_meta) as dest:
406
+ dest.write(mosaic)
407
+
408
+ # raw_NASA_tif_file_merger("third.tif","./1/*.tif")
409
+
410
+
411
+
412
+ # streamlit area
413
+ def main_loop():
414
+ st.title("HuggingFace Inference Demo")
415
+ st.subheader("Be sure to set the parameter")
416
+
417
+ [out1_orig_t1,out2_orig_t2,out3_orig_t3,out4_masked_t1,out5_masked_t2,out6_masked_t3,out7_pred_t1,out8_pred_t2,out9_pred_t3]=func(["first.tif","second.tif","third.tif"])
418
+
419
+
420
+ st.markdown("### first original image and masked image comparison")
421
+ image_comparison(
422
+ img1=out1_orig_t1,
423
+ img2=out4_masked_t1,
424
+ label1="original-1",
425
+ label2="masked-1",
426
+ width=1024,
427
+ )
428
+
429
+
430
+ st.markdown("### second original image and masked image comparison")
431
+ image_comparison(
432
+ img1=out2_orig_t2,
433
+ img2=out5_masked_t2,
434
+ label1="original-2",
435
+ label2="masked-2",
436
+ width=1024,
437
+ )
438
+
439
+
440
+ st.markdown("### thrid original image and masked image comparison")
441
+ image_comparison(
442
+ img1=out3_orig_t3,
443
+ img2=out6_masked_t3,
444
+ label1="original-1",
445
+ label2="masked-1",
446
+ width=1024,
447
+ )
448
+
449
+
450
+
451
+ st.markdown("### first original image and encoded image comparison")
452
+ image_comparison(
453
+ img1=out1_orig_t1,
454
+ img2=out7_pred_t1,
455
+ label1="original-1",
456
+ label2="masked-1",
457
+ width=1024,
458
+ )
459
+
460
+
461
+ st.markdown("### second original image and encoded image comparison")
462
+ image_comparison(
463
+ img1=out2_orig_t2,
464
+ img2=out8_pred_t2,
465
+ label1="original-2",
466
+ label2="masked-2",
467
+ width=1024,
468
+ )
469
+
470
+
471
+ st.markdown("### thrid original image and encoded image comparison")
472
+ image_comparison(
473
+ img1=out3_orig_t3,
474
+ img2=out9_pred_t3,
475
+ label1="original-1",
476
+ label2="masked-1",
477
+ width=1024,
478
+ )
479
+
480
+ if __name__ == '__main__':
481
+ main_loop()
482
+
483
+
first.tif ADDED

Git LFS Details

  • SHA256: 6501432e7602e31486e5367a74992340d84011ccc77774893914f8fbf1ad7310
  • Pointer size: 133 Bytes
  • Size of remote file: 26.8 MB
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ torch
+ torchvision
+ timm
+ einops
+ rasterio
second.tif ADDED

Git LFS Details

  • SHA256: ca9f32a906ed42c937c332bed3c391414484277eb9c38928b10ae82953099d38
  • Pointer size: 133 Bytes
  • Size of remote file: 26.8 MB
streamlit-testing.webm ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:590a810b981ac32d78a04f65ac78b4f2c15d77e465b95b659c1e1be83181e94a
+ size 10384274
temp/1ab7b057-9543-4240-92a1-f85bba853af6.jpg ADDED

Git LFS Details

  • SHA256: 1d2e7b6a2d13de3df0bc79da2862f4b7404a33b4d96cc7521af00afc58c123c1
  • Pointer size: 132 Bytes
  • Size of remote file: 5.68 MB
temp/771eca76-489a-41c7-bcfb-28b841e78dd7.jpg ADDED

Git LFS Details

  • SHA256: 3fa98a890ad524268288b68b6beeb5d597dfb25a5b1ed23bf69efbea6f6f7f6b
  • Pointer size: 132 Bytes
  • Size of remote file: 4.55 MB
third.tif ADDED

Git LFS Details

  • SHA256: ecf96e0671c573206a2861618007b3550cf95a0e89a553a5478de2c4371b396c
  • Pointer size: 133 Bytes
  • Size of remote file: 26.8 MB