Spaces:
Build error
Build error
jhj0517
commited on
Commit
•
e4defb0
1
Parent(s):
492539d
better device
Browse files- modules/sam.py +2 -1
modules/sam.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
2 |
import os
|
|
|
3 |
|
4 |
from modules.mask_utils import *
|
5 |
from modules.model_downloader import *
|
@@ -9,7 +10,7 @@ class SamInference:
|
|
9 |
def __init__(self):
|
10 |
self.model = None
|
11 |
self.model_path = f"models/sam_vit_h_4b8939.pth"
|
12 |
-
self.device = "cpu"
|
13 |
self.mask_generator = None
|
14 |
|
15 |
# Tuable Parameters , All default values
|
|
|
1 |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
|
2 |
import os
|
3 |
+
import torch
|
4 |
|
5 |
from modules.mask_utils import *
|
6 |
from modules.model_downloader import *
|
|
|
10 |
def __init__(self):
|
11 |
self.model = None
|
12 |
self.model_path = f"models/sam_vit_h_4b8939.pth"
|
13 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
self.mask_generator = None
|
15 |
|
16 |
# Tuable Parameters , All default values
|