52Hz commited on
Commit
036bd02
1 Parent(s): 9ae15c5

Update main_test_SRMNet.py

Browse files
Files changed (1) hide show
  1. main_test_SRMNet.py +5 -3
main_test_SRMNet.py CHANGED
@@ -34,12 +34,14 @@ def main():
34
  if len(files) == 0:
35
  raise Exception(f"No files found at {inp_dir}")
36
 
 
 
37
  # Load corresponding models architecture and weights
38
  model = SRMNet()
39
- model.cuda()
40
-
41
- load_checkpoint(model, args.weights)
42
  model.eval()
 
 
43
 
44
  mul = 16
45
  for file_ in files:
 
34
  if len(files) == 0:
35
  raise Exception(f"No files found at {inp_dir}")
36
 
37
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
38
+
39
  # Load corresponding models architecture and weights
40
  model = SRMNet()
41
+ model = model.to(device)
 
 
42
  model.eval()
43
+ load_checkpoint(model, args.weights)
44
+
45
 
46
  mul = 16
47
  for file_ in files: