Anonymous-sub commited on
Commit
deab087
1 Parent(s): 9c1dc83

Update flow/flow_utils.py

Browse files
Files changed (1) hide show
  1. flow/flow_utils.py +8 -5
flow/flow_utils.py CHANGED
@@ -12,6 +12,8 @@ sys.path.insert(0, gmflow_dir)
12
  from gmflow.gmflow import GMFlow # noqa: E702 E402 F401
13
  from utils.utils import InputPadder # noqa: E702 E402
14
 
 
 
15
 
16
  def coords_grid(b, h, w, homogeneous=False, device=None):
17
  y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
@@ -27,7 +29,7 @@ def coords_grid(b, h, w, homogeneous=False, device=None):
27
  grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
28
 
29
  if device is not None:
30
- grid = grid.to(device)
31
 
32
  return grid
33
 
@@ -117,7 +119,8 @@ def get_warped_and_mask(flow_model,
117
  if image3 is None:
118
  image3 = image1
119
  padder = InputPadder(image1.shape, padding_factor=8)
120
- image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
 
121
  results_dict = flow_model(image1,
122
  image2,
123
  attn_splits_list=[2],
@@ -150,8 +153,7 @@ class FlowCalc():
150
  attention_type='swin',
151
  ffn_dim_expansion=4,
152
  num_transformer_layers=6,
153
- ).to('cuda')
154
-
155
  checkpoint = torch.load(model_path,
156
  map_location=lambda storage, loc: storage)
157
  weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
@@ -168,7 +170,8 @@ class FlowCalc():
168
  image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
169
  image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
170
  padder = InputPadder(image1.shape, padding_factor=8)
171
- image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
 
172
  results_dict = self.model(image1,
173
  image2,
174
  attn_splits_list=[2],
 
12
  from gmflow.gmflow import GMFlow # noqa: E702 E402 F401
13
  from utils.utils import InputPadder # noqa: E702 E402
14
 
15
+ global_device = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+
17
 
18
  def coords_grid(b, h, w, homogeneous=False, device=None):
19
  y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
 
29
  grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
30
 
31
  if device is not None:
32
+ grid = grid.to(global_device)
33
 
34
  return grid
35
 
 
119
  if image3 is None:
120
  image3 = image1
121
  padder = InputPadder(image1.shape, padding_factor=8)
122
+ image1, image2 = padder.pad(image1[None].to(global_device),
123
+ image2[None].to(global_device))
124
  results_dict = flow_model(image1,
125
  image2,
126
  attn_splits_list=[2],
 
153
  attention_type='swin',
154
  ffn_dim_expansion=4,
155
  num_transformer_layers=6,
156
+ ).to(global_device)
 
157
  checkpoint = torch.load(model_path,
158
  map_location=lambda storage, loc: storage)
159
  weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
 
170
  image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
171
  image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
172
  padder = InputPadder(image1.shape, padding_factor=8)
173
+ image1, image2 = padder.pad(image1[None].to(global_device),
174
+ image2[None].to(global_device))
175
  results_dict = self.model(image1,
176
  image2,
177
  attn_splits_list=[2],