echen01 commited on
Commit
f7bf9fb
1 Parent(s): 5b7158a

fix deeplab device

Browse files
Files changed (2) hide show
  1. criteria/deeplab.py +2 -2
  2. criteria/mask.py +1 -0
criteria/deeplab.py CHANGED
@@ -309,7 +309,7 @@ def resnet50(pretrained=False, **kwargs):
309
  return model
310
 
311
 
312
- def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, **kwargs):
313
  """Constructs a ResNet-101 model.
314
 
315
  Args:
@@ -326,7 +326,7 @@ def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, **
326
  model_dict = model.state_dict()
327
  if num_groups and weight_std:
328
  path = os.path.join(os.path.dirname(path), "R-101-GN-WS.pth.tar")
329
- pretrained_dict = torch.load(path)
330
  overlap_dict = {
331
  k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict
332
  }
 
309
  return model
310
 
311
 
312
+ def resnet101(path=None, pretrained=False, num_groups=None, weight_std=False, device="cpu", **kwargs):
313
  """Constructs a ResNet-101 model.
314
 
315
  Args:
 
326
  model_dict = model.state_dict()
327
  if num_groups and weight_std:
328
  path = os.path.join(os.path.dirname(path), "R-101-GN-WS.pth.tar")
329
+ pretrained_dict = torch.load(path, map_location=device)
330
  overlap_dict = {
331
  k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict
332
  }
criteria/mask.py CHANGED
@@ -36,6 +36,7 @@ class Mask(nn.Module):
36
  num_groups=32,
37
  weight_std=True,
38
  beta=False,
 
39
  )
40
  .eval()
41
  .requires_grad_(False)
 
36
  num_groups=32,
37
  weight_std=True,
38
  beta=False,
39
+ device=device,
40
  )
41
  .eval()
42
  .requires_grad_(False)