mkalia commited on
Commit
d7357cf
1 Parent(s): 30d6ec7

Update depth_decoder.py

Browse files
Files changed (1) hide show
  1. depth_decoder.py +3 -1
depth_decoder.py CHANGED
@@ -18,6 +18,8 @@ class DepthDecoder(nn.Module):
18
  def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, batch_norm = True):
19
  super(DepthDecoder, self).__init__()
20
 
 
 
21
  self.num_output_channels = num_output_channels
22
  self.use_skips = use_skips
23
  self.upsample_mode = 'nearest'
@@ -70,7 +72,7 @@ class DepthDecoder(nn.Module):
70
  x = torch.cat(x, 1)
71
  x = self.convs[("upconv", i, 1)](x)
72
  if self.batch_norm:
73
- x = self.bn[('bn', i)].cuda()(x)
74
 
75
 
76
  # batchnorm
 
18
  def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, batch_norm = True):
19
  super(DepthDecoder, self).__init__()
20
 
21
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+
23
  self.num_output_channels = num_output_channels
24
  self.use_skips = use_skips
25
  self.upsample_mode = 'nearest'
 
72
  x = torch.cat(x, 1)
73
  x = self.convs[("upconv", i, 1)](x)
74
  if self.batch_norm:
75
+ x = self.bn[('bn', i)].to(self.device)(x)
76
 
77
 
78
  # batchnorm