mkalia commited on
Commit
73180f7
1 Parent(s): a50312e

Update depth_app.py

Browse files
Files changed (1) hide show
  1. depth_app.py +6 -3
depth_app.py CHANGED
@@ -7,8 +7,11 @@ import io
7
  from torchvision import transforms
8
  import matplotlib as mpl
9
  import matplotlib.cm as cm
10
- import networks
11
  from layers import disp_to_depth_no_scaling
 
 
 
 
12
 
13
  # Function to load the model
14
  def load_model(device, model_path):
@@ -16,7 +19,7 @@ def load_model(device, model_path):
16
  encoder_path = os.path.join(model_path, "encoder.pth")
17
  depth_decoder_path = os.path.join(model_path, "depth.pth")
18
 
19
- encoder = networks.ResnetEncoder(18, False)
20
  loaded_dict_enc = torch.load(encoder_path, map_location=device)
21
  feed_height = loaded_dict_enc['height']
22
  feed_width = loaded_dict_enc['width']
@@ -25,7 +28,7 @@ def load_model(device, model_path):
25
  encoder.to(device)
26
  encoder.eval()
27
 
28
- depth_decoder = networks.DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
29
  loaded_dict = torch.load(depth_decoder_path, map_location=device)
30
  depth_decoder.load_state_dict(loaded_dict, strict=False)
31
  depth_decoder.to(device)
 
7
  from torchvision import transforms
8
  import matplotlib as mpl
9
  import matplotlib.cm as cm
 
10
  from layers import disp_to_depth_no_scaling
11
+ from resnet_encoder import ResnetEncoder
12
+ from depth_decoder import DepthDecoder
13
+ # from pose_decoder import PoseDecoder
14
+ # from pose_cnn import PoseCNN
15
 
16
  # Function to load the model
17
  def load_model(device, model_path):
 
19
  encoder_path = os.path.join(model_path, "encoder.pth")
20
  depth_decoder_path = os.path.join(model_path, "depth.pth")
21
 
22
+ encoder = ResnetEncoder(18, False)
23
  loaded_dict_enc = torch.load(encoder_path, map_location=device)
24
  feed_height = loaded_dict_enc['height']
25
  feed_width = loaded_dict_enc['width']
 
28
  encoder.to(device)
29
  encoder.eval()
30
 
31
+ depth_decoder = DepthDecoder(num_ch_enc=encoder.num_ch_enc, scales=range(4))
32
  loaded_dict = torch.load(depth_decoder_path, map_location=device)
33
  depth_decoder.load_state_dict(loaded_dict, strict=False)
34
  depth_decoder.to(device)