Update audiosr/pipeline.py
Browse files- audiosr/pipeline.py +2 -2
audiosr/pipeline.py
CHANGED
@@ -118,8 +118,8 @@ def round_up_duration(duration):
|
|
118 |
def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
|
119 |
if device is None or device == "auto":
|
120 |
if torch.cuda.is_available():
|
121 |
-
device = torch.Tensor([0]).cuda()
|
122 |
-
|
123 |
elif torch.backends.mps.is_available():
|
124 |
device = torch.device("mps")
|
125 |
else:
|
|
|
118 |
def build_model(ckpt_path=None, config=None, device=None, model_name="basic"):
|
119 |
if device is None or device == "auto":
|
120 |
if torch.cuda.is_available():
|
121 |
+
# device = torch.Tensor([0]).cuda()
|
122 |
+
device = torch.device("cuda:0")
|
123 |
elif torch.backends.mps.is_available():
|
124 |
device = torch.device("mps")
|
125 |
else:
|