jhj0517 commited on
Commit
e4defb0
1 Parent(s): 492539d

better device

Browse files
Files changed (1) hide show
  1. modules/sam.py +2 -1
modules/sam.py CHANGED
@@ -1,5 +1,6 @@
1
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
2
  import os
 
3
 
4
  from modules.mask_utils import *
5
  from modules.model_downloader import *
@@ -9,7 +10,7 @@ class SamInference:
9
  def __init__(self):
10
  self.model = None
11
  self.model_path = f"models/sam_vit_h_4b8939.pth"
12
- self.device = "cpu"
13
  self.mask_generator = None
14
 
15
  # Tuable Parameters , All default values
 
1
  from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
2
  import os
3
+ import torch
4
 
5
  from modules.mask_utils import *
6
  from modules.model_downloader import *
 
10
  def __init__(self):
11
  self.model = None
12
  self.model_path = f"models/sam_vit_h_4b8939.pth"
13
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  self.mask_generator = None
15
 
16
  # Tuable Parameters , All default values