fix device error
Browse files- .idea/workspace.xml +2 -3
- utils.py +3 -1
.idea/workspace.xml
CHANGED
@@ -6,8 +6,7 @@
|
|
6 |
<component name="ChangeListManager">
|
7 |
<list default="true" id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="">
|
8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
9 |
-
<change beforePath="$PROJECT_DIR$/
|
10 |
-
<change beforePath="$PROJECT_DIR$/networks/backbone/__init__.py" beforeDir="false" afterPath="$PROJECT_DIR$/networks/backbone/__init__.py" afterDir="false" />
|
11 |
</list>
|
12 |
<option name="SHOW_DIALOG" value="false" />
|
13 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
@@ -49,7 +48,7 @@
|
|
49 |
<option name="presentableId" value="Default" />
|
50 |
<updated>1664204268713</updated>
|
51 |
<workItem from="1664204270261" duration="37000" />
|
52 |
-
<workItem from="1664204316867" duration="
|
53 |
</task>
|
54 |
<servers />
|
55 |
</component>
|
|
|
6 |
<component name="ChangeListManager">
|
7 |
<list default="true" id="9fb9e207-fc4f-4ff3-9adc-3c4c1e67daa7" name="Changes" comment="">
|
8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
9 |
+
<change beforePath="$PROJECT_DIR$/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/utils.py" afterDir="false" />
|
|
|
10 |
</list>
|
11 |
<option name="SHOW_DIALOG" value="false" />
|
12 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
|
48 |
<option name="presentableId" value="Default" />
|
49 |
<updated>1664204268713</updated>
|
50 |
<workItem from="1664204270261" duration="37000" />
|
51 |
+
<workItem from="1664204316867" duration="5530000" />
|
52 |
</task>
|
53 |
<servers />
|
54 |
</component>
|
utils.py
CHANGED
@@ -7,8 +7,10 @@ from networks import convert_to_separable_conv, set_bn_momentum
|
|
7 |
|
8 |
def get_network() -> torch.nn.Module:
|
9 |
network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
|
|
|
10 |
state_dict = torch.hub.load_state_dict_from_url(
|
11 |
-
"https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt"
|
|
|
12 |
)
|
13 |
network.backbone.load_state_dict(state_dict, strict=True)
|
14 |
convert_to_separable_conv(network.classifier)
|
|
|
7 |
|
8 |
def get_network() -> torch.nn.Module:
|
9 |
network = deeplabv3plus_resnet50(num_classes=21, pretrained_backbone=False)
|
10 |
+
|
11 |
state_dict = torch.hub.load_state_dict_from_url(
|
12 |
+
"https://www.robots.ox.ac.uk/~vgg/research/namedmask/shared_files/voc2012/namedmask_voc2012.pt",
|
13 |
+
map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
)
|
15 |
network.backbone.load_state_dict(state_dict, strict=True)
|
16 |
convert_to_separable_conv(network.classifier)
|