vict0rsch commited on
Commit
c1c4fcb
1 Parent(s): 2b1d700

fix fire inference

Browse files
Files changed (3) hide show
  1. app.py +4 -5
  2. climategan/fire.py +6 -5
  3. requirements.txt +1 -0
app.py CHANGED
@@ -2,9 +2,9 @@
2
  # thank you @NimaBoscarino
3
 
4
  import os
 
5
  from textwrap import dedent
6
  from urllib import parse
7
- from requests import get
8
 
9
  import googlemaps
10
  import gradio as gr
@@ -20,8 +20,8 @@ from gradio.components import (
20
  Row,
21
  Textbox,
22
  )
 
23
  from skimage import io
24
- from datetime import datetime
25
 
26
  from climategan_wrapper import ClimateGAN
27
 
@@ -76,9 +76,9 @@ TEXTS = [
76
    |  
77
  Read the original
78
  <a
79
- href='https://openreview.net/forum?id=EZNOb_uNpJk'
80
  target='_blank'>
81
- ICLR 2021 ClimateGAN paper
82
  </a>
83
  </p>
84
  """
@@ -217,7 +217,6 @@ def predict(cg: ClimateGAN, api_key):
217
 
218
 
219
  if __name__ == "__main__":
220
-
221
  ip = get("https://api.ipify.org").content.decode("utf8")
222
  print("My public IP address is: {}".format(ip))
223
 
2
  # thank you @NimaBoscarino
3
 
4
  import os
5
+ from datetime import datetime
6
  from textwrap import dedent
7
  from urllib import parse
 
8
 
9
  import googlemaps
10
  import gradio as gr
20
  Row,
21
  Textbox,
22
  )
23
+ from requests import get
24
  from skimage import io
 
25
 
26
  from climategan_wrapper import ClimateGAN
27
 
76
  &nbsp;&nbsp;|&nbsp;&nbsp;
77
  Read the original
78
  <a
79
+ href='https://arxiv.org/abs/2110.02871'
80
  target='_blank'>
81
+ ICLR 2022 ClimateGAN paper
82
  </a>
83
  </p>
84
  """
217
 
218
 
219
  if __name__ == "__main__":
 
220
  ip = get("https://api.ipify.org").content.decode("utf8")
221
  print("My public IP address is: {}".format(ip))
222
 
climategan/fire.py CHANGED
@@ -1,7 +1,8 @@
1
- import torch
2
- import torch.nn.functional as F
3
  import random
 
4
  import kornia
 
 
5
  from torchvision.transforms.functional import adjust_brightness, adjust_contrast
6
 
7
  from climategan.tutils import normalize, retrieve_sky_mask
@@ -105,9 +106,9 @@ def add_fire(x, seg_preds, fire_opts):
105
  kernel_size = (fire_opts.get("kernel_size", 301), fire_opts.get("kernel_size", 301))
106
  sigma = (fire_opts.get("kernel_sigma", 150.5), fire_opts.get("kernel_sigma", 150.5))
107
  border_type = "reflect"
108
- kernel = torch.unsqueeze(
109
- kornia.filters.kernels.get_gaussian_kernel2d(kernel_size, sigma), dim=0
110
- ).to(x.device)
111
  sky_mask = filter2d(sky_mask, kernel, border_type)
112
 
113
  filter_ = torch.ones(wildfire_tens.shape, device=x.device)
 
 
1
  import random
2
+
3
  import kornia
4
+ import torch
5
+ import torch.nn.functional as F
6
  from torchvision.transforms.functional import adjust_brightness, adjust_contrast
7
 
8
  from climategan.tutils import normalize, retrieve_sky_mask
106
  kernel_size = (fire_opts.get("kernel_size", 301), fire_opts.get("kernel_size", 301))
107
  sigma = (fire_opts.get("kernel_sigma", 150.5), fire_opts.get("kernel_sigma", 150.5))
108
  border_type = "reflect"
109
+ kernel = kornia.filters.kernels.get_gaussian_kernel2d(kernel_size, sigma)
110
+ if kernel.ndim == 2:
111
+ kernel = kernel.unsqueeze(0)
112
  sky_mask = filter2d(sky_mask, kernel, border_type)
113
 
114
  filter_ = torch.ones(wildfire_tens.shape, device=x.device)
requirements.txt CHANGED
@@ -2,6 +2,7 @@ gradio==3.44.1
2
  torch
3
  torch-optimizer
4
  torchvision
 
5
  addict
6
  aiohttp
7
  aiosignal
2
  torch
3
  torch-optimizer
4
  torchvision
5
+ accelerate
6
  addict
7
  aiohttp
8
  aiosignal