Spaces:
Sleeping
Sleeping
Update depth_app.py
Browse files- depth_app.py +6 -3
depth_app.py
CHANGED
@@ -7,8 +7,11 @@ import io
|
|
7 |
from torchvision import transforms
|
8 |
import matplotlib as mpl
|
9 |
import matplotlib.cm as cm
|
10 |
-
import networks
|
11 |
from layers import disp_to_depth_no_scaling
|
|
|
|
|
|
|
|
|
12 |
|
13 |
# Function to load the model
|
14 |
def load_model(device, model_path):
|
@@ -16,7 +19,7 @@ def load_model(device, model_path):
|
|
16 |
encoder_path = os.path.join(model_path, "encoder.pth")
|
17 |
depth_decoder_path = os.path.join(model_path, "depth.pth")
|
18 |
|
19 |
-
encoder =
|
20 |
loaded_dict_enc = torch.load(encoder_path, map_location=device)
|
21 |
feed_height = loaded_dict_enc['height']
|
22 |
feed_width = loaded_dict_enc['width']
|
@@ -25,7 +28,7 @@ def load_model(device, model_path):
|
|
25 |
encoder.to(device)
|
26 |
encoder.eval()
|
27 |
|
28 |
-
depth_decoder =
|
29 |
loaded_dict = torch.load(depth_decoder_path, map_location=device)
|
30 |
depth_decoder.load_state_dict(loaded_dict, strict=False)
|
31 |
depth_decoder.to(device)
|
|
|
7 |
from torchvision import transforms
|
8 |
import matplotlib as mpl
|
9 |
import matplotlib.cm as cm
|
|
|
10 |
from layers import disp_to_depth_no_scaling
|
11 |
+
from resnet_encoder import ResnetEncoder
|
12 |
+
from depth_decoder import DepthDecoder
|
13 |
+
# from pose_decoder import PoseDecoder
|
14 |
+
# from pose_cnn import PoseCNN
|
15 |
|
16 |
# Function to load the model
|
17 |
def load_model(device, model_path):
|
|
|
19 |
encoder_path = os.path.join(model_path, "encoder.pth")
|
20 |
depth_decoder_path = os.path.join(model_path, "depth.pth")
|
21 |
|
22 |
+
encoder = ResnetEncoder(18, False)
|
23 |
loaded_dict_enc = torch.load(encoder_path, map_location=device)
|
24 |
feed_height = loaded_dict_enc['height']
|
25 |
feed_width = loaded_dict_enc['width']
|
|
|
28 |
encoder.to(device)
|
29 |
encoder.eval()
|
30 |
|
31 |
+
depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
|
32 |
loaded_dict = torch.load(depth_decoder_path, map_location=device)
|
33 |
depth_decoder.load_state_dict(loaded_dict, strict=False)
|
34 |
depth_decoder.to(device)
|