noelshin commited on
Commit
bcc8459
·
1 Parent(s): 8513597

fix device error

Browse files
Files changed (2) hide show
  1. .idea/workspace.xml +2 -3
  2. utils.py +3 -1
.idea/workspace.xml CHANGED
@@ -6,8 +6,7 @@
6
  <component name="ChangeListManager">
7
  <list default="true" id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="">
8
  <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
9
- <change beforePath="$PROJECT_DIR$/networks/_deeplab.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/_deeplab.py" afterDir="false" />
10
- <change beforePath="$PROJECT_DIR$/networks/backbone/__init__.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/backbone/__init__.py" afterDir="false" />
11
  </list>
12
  <option name="SHOW_DIALOG" value="false" />
13
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -49,7 +48,7 @@
49
  <option name="presentableId" value="Default" />
50
  <updated>1664204268713</updated>
51
  <workItem from="1664204270261" duration="37000" />
52
- <workItem from="1664204316867" duration="5130000" />
53
  </task>
54
  <servers />
55
  </component>
 
6
  <component name="ChangeListManager">
7
  <list default="true" id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="">
8
  <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
9
+ <change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" />
 
10
  </list>
11
  <option name="SHOW_DIALOG" value="false" />
12
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
 
48
  <option name="presentableId" value="Default" />
49
  <updated>1664204268713</updated>
50
  <workItem from="1664204270261" duration="37000" />
51
+ <workItem from="1664204316867" duration="5530000" />
52
  </task>
53
  <servers />
54
  </component>
utils.py CHANGED
@@ -7,8 +7,10 @@ from networks import convert_to_separable_conv, set_bn_momentum
7
 
8
  def get_network() -> torch.nn.Module:
9
  network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
 
10
  state_dict = torch.hub.load_state_dict_from_url(
11
- "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt"
 
12
  )
13
  network.backbone.load_state_dict(state_dict, strict=True)
14
  convert_to_separable_conv(network.classifier)
 
7
 
8
  def get_network() -> torch.nn.Module:
9
  network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
10
+
11
  state_dict = torch.hub.load_state_dict_from_url(
12
+ "https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
13
+ map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  )
15
  network.backbone.load_state_dict(state_dict, strict=True)
16
  convert_to_separable_conv(network.classifier)