YannisK commited on
Commit
cb6bc7f
1 Parent(s): ce415a7
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -29,10 +29,18 @@ state = torch.load('fire.pth', map_location='cpu')
29
  state['net_params']['pretrained'] = None # no need for imagenet pretrained model
30
  net_sfm = fire_network.init_network(**state['net_params']).to(device)
31
  net_sfm.load_state_dict(state['state_dict'])
 
 
 
 
 
32
 
33
  state2 = torch.load('fire_imagenet.pth', map_location='cpu')
 
 
 
34
  net_imagenet = fire_network.init_network(**state['net_params']).to(device)
35
- net_imagenet.load_state_dict(state2['state_dict'])
36
 
37
  # ---------------------------------------
38
  transform = transforms.Compose([
 
29
  state['net_params']['pretrained'] = None # no need for imagenet pretrained model
30
  net_sfm = fire_network.init_network(**state['net_params']).to(device)
31
  net_sfm.load_state_dict(state['state_dict'])
32
+ dim_red_params_dict = {}
33
+ for name, param in net_sfm.named_parameters():
34
+ if 'dim_reduction' in name:
35
+ dim_red_params_dict[name] = param
36
+
37
 
38
  state2 = torch.load('fire_imagenet.pth', map_location='cpu')
39
+ state2['net_params'] = state['net_params']
40
+ state2['state_dict'] += dim_red_params_dict
41
+ # state2['net_params'] =
42
  net_imagenet = fire_network.init_network(**state['net_params']).to(device)
43
+ net_imagenet.load_state_dict(state2['state_dict']) #, strict=False)
44
 
45
  # ---------------------------------------
46
  transform = transforms.Compose([