Spaces:
Runtime error
Runtime error
cocktailpeanut
commited on
Commit
·
854e73f
1
Parent(s):
63bc8dd
update
Browse files
app.py
CHANGED
@@ -5,6 +5,8 @@ import numpy as np
|
|
5 |
from aura_sr import AuraSR
|
6 |
import torch
|
7 |
#import spaces
|
|
|
|
|
8 |
|
9 |
# Force CPU usage
|
10 |
torch.set_default_tensor_type(torch.FloatTensor)
|
@@ -14,7 +16,7 @@ original_load = torch.load
|
|
14 |
torch.load = lambda *args, **kwargs: original_load(*args, **kwargs, map_location=torch.device('cpu'))
|
15 |
|
16 |
# Initialize the AuraSR model
|
17 |
-
aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR")
|
18 |
|
19 |
# Restore original torch.load
|
20 |
torch.load = original_load
|
|
|
5 |
from aura_sr import AuraSR
|
6 |
import torch
|
7 |
#import spaces
|
8 |
+
import devicetorch
|
9 |
+
device = devicetorch.get(torch)
|
10 |
|
11 |
# Force CPU usage
|
12 |
torch.set_default_tensor_type(torch.FloatTensor)
|
|
|
16 |
torch.load = lambda *args, **kwargs: original_load(*args, **kwargs, map_location=torch.device('cpu'))
|
17 |
|
18 |
# Initialize the AuraSR model
|
19 |
+
aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR", device=device)
|
20 |
|
21 |
# Restore original torch.load
|
22 |
torch.load = original_load
|