Vincentqyw commited on
Commit
8af5ecd
·
1 Parent(s): 94d81f4

update: roma

Browse files
third_party/Roma/roma/models/encoders.py CHANGED
@@ -38,10 +38,13 @@ class ResNet50(nn.Module):
38
  self.freeze_bn = freeze_bn
39
  self.early_exit = early_exit
40
  self.amp = amp
41
- if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
42
- self.amp_dtype = torch.bfloat16
 
 
 
43
  else:
44
- self.amp_dtype = torch.float16
45
 
46
  def forward(self, x, **kwargs):
47
  with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
@@ -78,10 +81,13 @@ class VGG19(nn.Module):
78
  super().__init__()
79
  self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
80
  self.amp = amp
81
- if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
82
- self.amp_dtype = torch.bfloat16
 
 
 
83
  else:
84
- self.amp_dtype = torch.float16
85
 
86
  def forward(self, x, **kwargs):
87
  with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
@@ -121,10 +127,13 @@ class CNNandDinov2(nn.Module):
121
  else:
122
  self.cnn = VGG19(**cnn_kwargs)
123
  self.amp = amp
124
- if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
125
- self.amp_dtype = torch.bfloat16
 
 
 
126
  else:
127
- self.amp_dtype = torch.float16
128
  if self.amp:
129
  dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
130
  self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
 
38
  self.freeze_bn = freeze_bn
39
  self.early_exit = early_exit
40
  self.amp = amp
41
+ if torch.cuda.is_available():
42
+ if torch.cuda.is_bf16_supported():
43
+ self.amp_dtype = torch.bfloat16
44
+ else:
45
+ self.amp_dtype = torch.float16
46
  else:
47
+ self.amp_dtype = torch.float32
48
 
49
  def forward(self, x, **kwargs):
50
  with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
 
81
  super().__init__()
82
  self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
83
  self.amp = amp
84
+ if torch.cuda.is_available():
85
+ if torch.cuda.is_bf16_supported():
86
+ self.amp_dtype = torch.bfloat16
87
+ else:
88
+ self.amp_dtype = torch.float16
89
  else:
90
+ self.amp_dtype = torch.float32
91
 
92
  def forward(self, x, **kwargs):
93
  with torch.autocast("cuda", enabled=self.amp, dtype=self.amp_dtype):
 
127
  else:
128
  self.cnn = VGG19(**cnn_kwargs)
129
  self.amp = amp
130
+ if torch.cuda.is_available():
131
+ if torch.cuda.is_bf16_supported():
132
+ self.amp_dtype = torch.bfloat16
133
+ else:
134
+ self.amp_dtype = torch.float16
135
  else:
136
+ self.amp_dtype = torch.float32
137
  if self.amp:
138
  dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
139
  self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
third_party/Roma/roma/models/matcher.py CHANGED
@@ -76,10 +76,13 @@ class ConvRefiner(nn.Module):
76
  self.disable_local_corr_grad = disable_local_corr_grad
77
  self.is_classifier = is_classifier
78
  self.sample_mode = sample_mode
79
- if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
80
- self.amp_dtype = torch.bfloat16
 
 
 
81
  else:
82
- self.amp_dtype = torch.float16
83
 
84
  def create_block(
85
  self,
@@ -337,10 +340,13 @@ class Decoder(nn.Module):
337
  self.displacement_dropout_p = displacement_dropout_p
338
  self.gm_warp_dropout_p = gm_warp_dropout_p
339
  self.flow_upsample_mode = flow_upsample_mode
340
- if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
341
- self.amp_dtype = torch.bfloat16
 
 
 
342
  else:
343
- self.amp_dtype = torch.float16
344
 
345
  def get_placeholder_flow(self, b, h, w, device):
346
  coarse_coords = torch.meshgrid(
 
76
  self.disable_local_corr_grad = disable_local_corr_grad
77
  self.is_classifier = is_classifier
78
  self.sample_mode = sample_mode
79
+ if torch.cuda.is_available():
80
+ if torch.cuda.is_bf16_supported():
81
+ self.amp_dtype = torch.bfloat16
82
+ else:
83
+ self.amp_dtype = torch.float16
84
  else:
85
+ self.amp_dtype = torch.float32
86
 
87
  def create_block(
88
  self,
 
340
  self.displacement_dropout_p = displacement_dropout_p
341
  self.gm_warp_dropout_p = gm_warp_dropout_p
342
  self.flow_upsample_mode = flow_upsample_mode
343
+ if torch.cuda.is_available():
344
+ if torch.cuda.is_bf16_supported():
345
+ self.amp_dtype = torch.bfloat16
346
+ else:
347
+ self.amp_dtype = torch.float16
348
  else:
349
+ self.amp_dtype = torch.float32
350
 
351
  def get_placeholder_flow(self, b, h, w, device):
352
  coarse_coords = torch.meshgrid(
third_party/Roma/roma/models/transformer/__init__.py CHANGED
@@ -30,10 +30,14 @@ class TransformerDecoder(nn.Module):
30
  self._scales = [16]
31
  self.is_classifier = is_classifier
32
  self.amp = amp
33
- if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
34
- self.amp_dtype = torch.bfloat16
 
 
 
35
  else:
36
- self.amp_dtype = torch.float16
 
37
  self.pos_enc = pos_enc
38
  self.learned_embeddings = learned_embeddings
39
  if self.learned_embeddings:
 
30
  self._scales = [16]
31
  self.is_classifier = is_classifier
32
  self.amp = amp
33
+ if torch.cuda.is_available():
34
+ if torch.cuda.is_bf16_supported():
35
+ self.amp_dtype = torch.bfloat16
36
+ else:
37
+ self.amp_dtype = torch.float16
38
  else:
39
+ self.amp_dtype = torch.float32
40
+
41
  self.pos_enc = pos_enc
42
  self.learned_embeddings = learned_embeddings
43
  if self.learned_embeddings: