Vincentqyw committed
Commit • a1335ed
Parent(s): a9b8ec2

fix: roma cpu
app.py CHANGED
@@ -86,8 +86,7 @@ def ui_reset_state(
 
 
 def run(config):
-    with gr.Blocks(css="footer {visibility: hidden}"
-    ) as app:
+    with gr.Blocks(css="footer {visibility: hidden}") as app:
         gr.Markdown(
             """
             <p align="center">
third_party/Roma/roma/models/encoders.py CHANGED
@@ -6,6 +6,8 @@ import torch.nn.functional as F
 import torchvision.models as tvm
 import gc
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 class ResNet50(nn.Module):
     def __init__(
@@ -47,7 +49,7 @@ class ResNet50(nn.Module):
         self.amp_dtype = torch.float32
 
     def forward(self, x, **kwargs):
-        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
+        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
             net = self.net
             feats = {1: x}
             x = net.conv1(x)
@@ -90,7 +92,7 @@ class VGG19(nn.Module):
         self.amp_dtype = torch.float32
 
     def forward(self, x, **kwargs):
-        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
+        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
             feats = {}
             scale = 1
             for layer in self.layers:
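The change applied throughout this commit is the same everywhere: the device type hard-coded as "cuda" is replaced by a module-level constant resolved at import time, so the code also runs on CPU-only machines. A minimal self-contained sketch of the pattern (the class and tensor names here are illustrative, not from the repo):

import torch
import torch.nn as nn

# Resolve the autocast device type once at import time, as the commit does.
device = "cuda" if torch.cuda.is_available() else "cpu"


class TinyEncoder(nn.Module):
    # Illustrative stand-in for the ResNet50/VGG19 encoders above.
    def __init__(self, amp=False, amp_dtype=torch.float16):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.amp = amp
        self.amp_dtype = amp_dtype

    def forward(self, x):
        # torch.autocast takes a device *type* string ("cuda" or "cpu"),
        # so the same line works on both kinds of machine.
        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
            return self.conv(x)


x = torch.randn(1, 3, 32, 32, device=device)
print(TinyEncoder()(x).shape)  # torch.Size([1, 8, 32, 32])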
third_party/Roma/roma/models/matcher.py CHANGED
@@ -14,6 +14,8 @@ from roma.utils.local_correlation import local_correlation
 from roma.utils.utils import cls_to_flow_refine
 from roma.utils.kde import kde
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 class ConvRefiner(nn.Module):
     def __init__(
@@ -118,7 +120,7 @@ class ConvRefiner(nn.Module):
 
     def forward(self, x, y, flow, scale_factor=1, logits=None):
         b, c, hs, ws = x.shape
-        with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
+        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
             with torch.no_grad():
                 x_hat = F.grid_sample(
                     y,
@@ -129,8 +131,8 @@
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
                     (
-                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                     )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -423,7 +425,7 @@ class Decoder(nn.Module):
                 corresps[ins] = {}
                 f1_s, f2_s = f1[ins], f2[ins]
                 if new_scale in self.proj:
-                    with torch.autocast("cuda", self.amp_dtype):
+                    with torch.autocast(device, self.amp_dtype):
                         f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
 
                 if ins in coarse_scales:
@@ -643,7 +645,7 @@ class RegressionMatcher(nn.Module):
         device=None,
     ):
         if device is None:
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            device = torch.device(device if torch.cuda.is_available() else "cpu")
         from PIL import Image
 
         if isinstance(im_A_path, (str, os.PathLike)):
@@ -739,8 +741,8 @@
         # Create im_A meshgrid
         im_A_coords = torch.meshgrid(
             (
-                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
             )
         )
         im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
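One caveat in the @@ -643,7 +645,7 @@ hunk: inside the method, the device parameter shadows the module-level device constant, so when the parameter is None the rewritten expression passes None (not the constant) to torch.device whenever CUDA is available. A shadowing-free fallback might look like this sketch (the names are illustrative, not the repo's API):

import torch

# Module-level default, as introduced by this commit.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


def resolve_device(device=None):
    # Fall back to the module-level default instead of reusing the
    # possibly-None parameter.
    return torch.device(device if device is not None else DEFAULT_DEVICE)


print(resolve_device())       # cuda or cpu, depending on the machine
print(resolve_device("cpu"))  # cpu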
third_party/Roma/roma/models/transformer/__init__.py CHANGED
@@ -7,6 +7,8 @@ from .layers.block import Block
 from .layers.attention import MemEffAttention
 from .dinov2 import vit_large
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 class TransformerDecoder(nn.Module):
     def __init__(
@@ -51,7 +53,7 @@ class TransformerDecoder(nn.Module):
         return self._scales.copy()
 
     def forward(self, gp_posterior, features, old_stuff, new_scale):
-        with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
+        with torch.autocast(device, dtype=self.amp_dtype, enabled=self.amp):
            B, C, H, W = gp_posterior.shape
            x = torch.cat((gp_posterior, features), dim=1)
            B, C, H, W = x.shape
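Note that autocast dtype support differs by device type: float16 autocast is primarily a CUDA feature, while CPU autocast is built around bfloat16 (newer PyTorch releases also accept float16 on CPU). A small sketch of choosing the dtype to match the resolved device (the variable names are illustrative):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Pick a low-precision dtype the backend actually supports.
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

x = torch.randn(4, 4, device=device)
w = torch.randn(4, 4, device=device)
with torch.autocast(device, dtype=amp_dtype):
    y = x @ w  # matmul is autocast-eligible on both backends
print(y.dtype)  # the low-precision dtype chosen above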
third_party/Roma/roma/utils/local_correlation.py CHANGED
@@ -1,6 +1,8 @@
 import torch
 import torch.nn.functional as F
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
 
 def local_correlation(
     feature0,
@@ -20,8 +22,8 @@ def local_correlation(
     # If flow is None, assume feature0 and feature1 are aligned
     coords = torch.meshgrid(
         (
-            torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device="cuda"),
-            torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device="cuda"),
+            torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
+            torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
         )
     )
     coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)
@@ -30,10 +32,10 @@
     local_window = torch.meshgrid(
         (
             torch.linspace(
-                -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device="cuda"
+                -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=device
             ),
             torch.linspace(
-                -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device="cuda"
+                -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=device
             ),
         )
     )
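A module-level constant works for this app, but grid helpers can also place tensors on the device of their inputs, which stays correct even when a caller moves data to a device other than the global default. A sketch of that alternative (not what the commit does; the helper name is hypothetical):

import torch


def make_coord_grid(feature):
    # Build the normalized (-1, 1) sampling grid on the same device as
    # the feature map, instead of using a module-level device constant.
    B, C, h, w = feature.shape
    coords = torch.meshgrid(
        (
            torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature.device),
            torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature.device),
        ),
        indexing="ij",
    )
    return torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)


feat = torch.randn(2, 16, 8, 8)  # CPU tensor
print(make_coord_grid(feat).shape)  # torch.Size([2, 8, 8, 2])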