noelshin commited on
Commit
6d6f3c6
·
1 Parent(s): bcc8459

fix state dict loading error

Browse files
Files changed (2) hide show
  1. .idea/workspace.xml +1 -1
  2. utils.py +4 -3
.idea/workspace.xml CHANGED
@@ -48,7 +48,7 @@
48
  <option name="presentableId" value="Default" />
49
  <updated>1664204268713</updated>
50
  <workItem from="1664204270261" duration="37000" />
51
- <workItem from="1664204316867" duration="5530000" />
52
  </task>
53
  <servers />
54
  </component>
 
48
  <option name="presentableId" value="Default" />
49
  <updated>1664204268713</updated>
50
  <workItem from="1664204270261" duration="37000" />
51
+ <workItem from="1664204316867" duration="5840000" />
52
  </task>
53
  <servers />
54
  </component>
utils.py CHANGED
@@ -8,13 +8,14 @@ from networks import convert_to_separable_conv, set_bn_momentum
8
  def get_network() -> torch.nn.Module:
9
  network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
10
 
 
 
 
11
  state_dict = torch.hub.load_state_dict_from_url(
12
  "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
13
  map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  )
15
- network.backbone.load_state_dict(state_dict, strict=True)
16
- convert_to_separable_conv(network.classifier)
17
- set_bn_momentum(network.backbone, momentum=0.01)
18
  return network
19
 
20
 
 
8
  def get_network() -> torch.nn.Module:
9
  network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
10
 
11
+ convert_to_separable_conv(network.classifier)
12
+ set_bn_momentum(network.backbone, momentum=0.01)
13
+
14
  state_dict = torch.hub.load_state_dict_from_url(
15
  "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
16
  map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  )
18
+ network.load_state_dict(state_dict, strict=True)
 
 
19
  return network
20
 
21