asdasdasdasd commited on
Commit
e71c28e
1 Parent(s): aa85de6

Update detect_from_videos.py

Browse files
Files changed (1) hide show
  1. detect_from_videos.py +4 -0
detect_from_videos.py CHANGED
@@ -11,6 +11,9 @@ from tqdm import tqdm
11
  from model_core import Two_Stream_Net
12
  from torchvision import transforms
13
 
 
 
 
14
  xception_default_data_transforms_256 = {
15
  'train': transforms.Compose([
16
  transforms.Resize((256, 256)),
@@ -148,6 +151,7 @@ def test_full_image_network(video_path, model_path, output_path,
148
  # model, *_ = model_selection(modelname='xception', num_out_classes=2)
149
  model = Two_Stream_Net()
150
  model.load_state_dict(torch.load(model_path))
 
151
  model.eval()
152
 
153
  if cuda:
 
11
  from model_core import Two_Stream_Net
12
  from torchvision import transforms
13
 
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+
17
  xception_default_data_transforms_256 = {
18
  'train': transforms.Compose([
19
  transforms.Resize((256, 256)),
 
151
  # model, *_ = model_selection(modelname='xception', num_out_classes=2)
152
  model = Two_Stream_Net()
153
  model.load_state_dict(torch.load(model_path))
154
+ model = model.to(device)
155
  model.eval()
156
 
157
  if cuda: