Spaces:
Sleeping
Sleeping
Update depth_decoder.py
Browse files- depth_decoder.py +3 -1
depth_decoder.py
CHANGED
@@ -18,6 +18,8 @@ class DepthDecoder(nn.Module):
|
|
18 |
def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, batch_norm = True):
|
19 |
super(DepthDecoder, self).__init__()
|
20 |
|
|
|
|
|
21 |
self.num_output_channels = num_output_channels
|
22 |
self.use_skips = use_skips
|
23 |
self.upsample_mode = 'nearest'
|
@@ -70,7 +72,7 @@ class DepthDecoder(nn.Module):
|
|
70 |
x = torch.cat(x, 1)
|
71 |
x = self.convs[("upconv", i, 1)](x)
|
72 |
if self.batch_norm:
|
73 |
-
x = self.bn[('bn', i)].
|
74 |
|
75 |
|
76 |
# batchnorm
|
|
|
18 |
def __init__(self, num_ch_enc, scales=range(4), num_output_channels=1, use_skips=True, batch_norm = True):
|
19 |
super(DepthDecoder, self).__init__()
|
20 |
|
21 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
22 |
+
|
23 |
self.num_output_channels = num_output_channels
|
24 |
self.use_skips = use_skips
|
25 |
self.upsample_mode = 'nearest'
|
|
|
72 |
x = torch.cat(x, 1)
|
73 |
x = self.convs[("upconv", i, 1)](x)
|
74 |
if self.batch_norm:
|
75 |
+
x = self.bn[('bn', i)].to(self.device)(x)
|
76 |
|
77 |
|
78 |
# batchnorm
|