anonymous commited on
Commit
2a8678c
1 Parent(s): 4fcfd85
Files changed (1) hide show
  1. src/img_util.py +3 -1
src/img_util.py CHANGED
@@ -2,6 +2,8 @@ import einops
2
  import torch
3
  import torch.nn.functional as F
4
 
 
 
5
 
6
  @torch.no_grad()
7
  def find_flat_region(mask):
@@ -18,6 +20,6 @@ def find_flat_region(mask):
18
 
19
 
20
  def numpy2tensor(img):
21
- x0 = torch.from_numpy(img.copy()).float().cuda() / 255.0 * 2.0 - 1.
22
  x0 = torch.stack([x0], dim=0)
23
  return einops.rearrange(x0, 'b h w c -> b c h w').clone()
 
2
  import torch
3
  import torch.nn.functional as F
4
 
5
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
6
+
7
 
8
  @torch.no_grad()
9
  def find_flat_region(mask):
 
20
 
21
 
22
  def numpy2tensor(img):
23
+ x0 = torch.from_numpy(img.copy()).float().to(device) / 255.0 * 2.0 - 1.
24
  x0 = torch.stack([x0], dim=0)
25
  return einops.rearrange(x0, 'b h w c -> b c h w').clone()