Spaces:
Runtime error
Runtime error
Anonymous-sub
commited on
Commit
•
deab087
1
Parent(s):
9c1dc83
Update flow/flow_utils.py
Browse files- flow/flow_utils.py +8 -5
flow/flow_utils.py
CHANGED
@@ -12,6 +12,8 @@ sys.path.insert(0, gmflow_dir)
|
|
12 |
from gmflow.gmflow import GMFlow # noqa: E702 E402 F401
|
13 |
from utils.utils import InputPadder # noqa: E702 E402
|
14 |
|
|
|
|
|
15 |
|
16 |
def coords_grid(b, h, w, homogeneous=False, device=None):
|
17 |
y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
|
@@ -27,7 +29,7 @@ def coords_grid(b, h, w, homogeneous=False, device=None):
|
|
27 |
grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
|
28 |
|
29 |
if device is not None:
|
30 |
-
grid = grid.to(
|
31 |
|
32 |
return grid
|
33 |
|
@@ -117,7 +119,8 @@ def get_warped_and_mask(flow_model,
|
|
117 |
if image3 is None:
|
118 |
image3 = image1
|
119 |
padder = InputPadder(image1.shape, padding_factor=8)
|
120 |
-
image1, image2 = padder.pad(image1[None].
|
|
|
121 |
results_dict = flow_model(image1,
|
122 |
image2,
|
123 |
attn_splits_list=[2],
|
@@ -150,8 +153,7 @@ class FlowCalc():
|
|
150 |
attention_type='swin',
|
151 |
ffn_dim_expansion=4,
|
152 |
num_transformer_layers=6,
|
153 |
-
).to(
|
154 |
-
|
155 |
checkpoint = torch.load(model_path,
|
156 |
map_location=lambda storage, loc: storage)
|
157 |
weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
|
@@ -168,7 +170,8 @@ class FlowCalc():
|
|
168 |
image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
|
169 |
image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
|
170 |
padder = InputPadder(image1.shape, padding_factor=8)
|
171 |
-
image1, image2 = padder.pad(image1[None].
|
|
|
172 |
results_dict = self.model(image1,
|
173 |
image2,
|
174 |
attn_splits_list=[2],
|
|
|
12 |
from gmflow.gmflow import GMFlow # noqa: E702 E402 F401
|
13 |
from utils.utils import InputPadder # noqa: E702 E402
|
14 |
|
15 |
+
global_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
16 |
+
|
17 |
|
18 |
def coords_grid(b, h, w, homogeneous=False, device=None):
|
19 |
y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
|
|
|
29 |
grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
|
30 |
|
31 |
if device is not None:
|
32 |
+
grid = grid.to(global_device)
|
33 |
|
34 |
return grid
|
35 |
|
|
|
119 |
if image3 is None:
|
120 |
image3 = image1
|
121 |
padder = InputPadder(image1.shape, padding_factor=8)
|
122 |
+
image1, image2 = padder.pad(image1[None].to(global_device),
|
123 |
+
image2[None].to(global_device))
|
124 |
results_dict = flow_model(image1,
|
125 |
image2,
|
126 |
attn_splits_list=[2],
|
|
|
153 |
attention_type='swin',
|
154 |
ffn_dim_expansion=4,
|
155 |
num_transformer_layers=6,
|
156 |
+
).to(global_device)
|
|
|
157 |
checkpoint = torch.load(model_path,
|
158 |
map_location=lambda storage, loc: storage)
|
159 |
weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
|
|
|
170 |
image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
|
171 |
image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
|
172 |
padder = InputPadder(image1.shape, padding_factor=8)
|
173 |
+
image1, image2 = padder.pad(image1[None].to(global_device),
|
174 |
+
image2[None].to(global_device))
|
175 |
results_dict = self.model(image1,
|
176 |
image2,
|
177 |
attn_splits_list=[2],
|