Realcat committed
Commit 7ee27ee (1 parent: 368319c)

fix: cpu roma

third_party/Roma/roma/models/encoders.py CHANGED
@@ -24,7 +24,10 @@ class ResNet50(nn.Module):
         self.freeze_bn = freeze_bn
         self.early_exit = early_exit
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
         with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -60,7 +63,10 @@ class VGG19(nn.Module):
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
         with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -94,7 +100,10 @@ class CNNandDinov2(nn.Module):
         else:
             self.cnn = VGG19(**cnn_kwargs)
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         if self.amp:
             dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
         self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
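
The same CPU fallback is added to all three encoders. As a minimal standalone sketch of that selection logic (the pick_amp_dtype helper name is hypothetical, not part of the commit):

import torch

def pick_amp_dtype() -> torch.dtype:
    # Hypothetical helper mirroring the dtype selection added to ResNet50,
    # VGG19 and CNNandDinov2: full precision on CPU-only machines,
    # reduced precision on CUDA.
    if not torch.cuda.is_available():
        return torch.float32
    # Prefer bfloat16 where the GPU supports it (wider dynamic range than float16).
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16

# e.g. in __init__: self.amp_dtype = pick_amp_dtype()
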
third_party/Roma/roma/models/matcher.py CHANGED
@@ -71,8 +71,12 @@ class ConvRefiner(nn.Module):
         self.disable_local_corr_grad = disable_local_corr_grad
         self.is_classifier = is_classifier
         self.sample_mode = sample_mode
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def create_block(
         self,
         in_dim,
@@ -109,8 +113,8 @@ class ConvRefiner(nn.Module):
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
                     (
-                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=self.device),
+                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=self.device),
                     )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -296,8 +300,11 @@ class Decoder(nn.Module):
         self.displacement_dropout_p = displacement_dropout_p
         self.gm_warp_dropout_p = gm_warp_dropout_p
         self.flow_upsample_mode = flow_upsample_mode
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def get_placeholder_flow(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
             (
@@ -615,8 +622,8 @@ class RegressionMatcher(nn.Module):
             # Create im_A meshgrid
             im_A_coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                 )
             )
             im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
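
For reference, an assumed usage example (not part of the commit) of the device-aware grid construction that replaces the hard-coded device="cuda" above; hs and ws are placeholder values:

import torch

# Same device selection the patched ConvRefiner stores as self.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

hs, ws = 4, 6  # placeholder grid height/width
im_A_coords = torch.meshgrid(
    (
        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
    )
)
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
print(im_A_coords.shape, im_A_coords.device)  # torch.Size([2, 4, 6]), cpu or cuda:0

On a CPU-only machine this now builds the grid on the CPU instead of failing at the device="cuda" tensor allocation, which is what the commit addresses.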