Paolo-Fraccaro committed on
Commit
f62a54e
1 Parent(s): ff3ecff

add padding

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -47,7 +47,7 @@ def process_channel_group(orig_img, new_img, channels, data_mean, data_std):
47
  for c in channels:
48
  orig_ch = orig_img[c, ...]
49
  valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
50
- valid_mask[orig_ch == 0.0001] = False
51
 
52
  # Back to original data range
53
  orig_ch = (orig_ch * data_std[c]) + data_mean[c]
@@ -138,8 +138,8 @@ def load_example(file_paths: List[str], mean: List[float], std: List[float]):
138
  imgs.append(img)
139
  metas.append(meta)
140
 
141
- imgs = np.stack(imgs, axis=0) # num_frames, img_size, img_size, C
142
- imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, img_size, img_size
143
  imgs = np.expand_dims(imgs, axis=0) # add batch dim
144
 
145
  return imgs, metas
@@ -308,7 +308,7 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
308
  norm_pix_loss=False)
309
 
310
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
311
- print(f"\n--> model has {total_params / 1e6} Million params.\n")
312
 
313
  model.to(device)
314
 
@@ -320,6 +320,12 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
320
 
321
  model.eval()
322
  channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
 
 
 
 
 
 
323
 
324
  # Build sliding window
325
  batch = torch.tensor(input_data, device='cpu')
@@ -348,13 +354,10 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
348
  mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
349
  h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
350
 
351
- # Mix original image with patches
352
- h, w = rec_imgs.shape[-2:]
353
- rec_imgs_full = batch.clone()
354
- rec_imgs_full[..., :h, :w] = rec_imgs
355
-
356
- mask_imgs_full = torch.ones_like(batch)
357
- mask_imgs_full[..., :h, :w] = mask_imgs
358
 
359
  # Build RGB images
360
  for d in meta_data:
@@ -363,7 +366,7 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
363
  # save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
364
  # channels, mean, std, output_dir, meta_data)
365
 
366
- outputs = extract_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
367
  channels, mean, std)
368
 
369
 
 
47
  for c in channels:
48
  orig_ch = orig_img[c, ...]
49
  valid_mask = torch.ones_like(orig_ch, dtype=torch.bool)
50
+ valid_mask[orig_ch == NO_DATA_FLOAT] = False
51
 
52
  # Back to original data range
53
  orig_ch = (orig_ch * data_std[c]) + data_mean[c]
 
138
  imgs.append(img)
139
  metas.append(meta)
140
 
141
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
142
+ imgs = np.moveaxis(imgs, -1, 0).astype('float32') # C, num_frames, H, W
143
  imgs = np.expand_dims(imgs, axis=0) # add batch dim
144
 
145
  return imgs, metas
 
308
  norm_pix_loss=False)
309
 
310
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
311
+ print(f"\n--> Model has {total_params:,} parameters.\n")
312
 
313
  model.to(device)
314
 
 
320
 
321
  model.eval()
322
  channels = [bands.index(b) for b in ['B04', 'B03', 'B02']] # BGR -> RGB
323
+
324
+ # Reflect pad if not divisible by img_size
325
+ original_h, original_w = input_data.shape[-2:]
326
+ pad_h = img_size - (original_h % img_size)
327
+ pad_w = img_size - (original_w % img_size)
328
+ input_data = np.pad(input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode='reflect')
329
 
330
  # Build sliding window
331
  batch = torch.tensor(input_data, device='cpu')
 
354
  mask_imgs = rearrange(mask_imgs, '(b h1 w1) c t h w -> b c t (h1 h) (w1 w)',
355
  h=img_size, w=img_size, b=1, c=len(bands), t=num_frames, h1=h1, w1=w1)
356
 
357
+ # Cut padded images back to original size
358
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
359
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
360
+ batch_full = batch[..., :original_h, :original_w]
 
 
 
361
 
362
  # Build RGB images
363
  for d in meta_data:
 
366
  # save_rgb_imgs(batch[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
367
  # channels, mean, std, output_dir, meta_data)
368
 
369
+ outputs = extract_rgb_imgs(batch_full[0, ...], rec_imgs_full[0, ...], mask_imgs_full[0, ...],
370
  channels, mean, std)
371
 
372