cocktailpeanut commited on
Commit
854e73f
·
1 Parent(s): 63bc8dd
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -5,6 +5,8 @@ import numpy as np
5
  from aura_sr import AuraSR
6
  import torch
7
  #import spaces
 
 
8
 
9
  # Force CPU usage
10
  torch.set_default_tensor_type(torch.FloatTensor)
@@ -14,7 +16,7 @@ original_load = torch.load
14
  torch.load = lambda *args, **kwargs: original_load(*args, **kwargs, map_location=torch.device('cpu'))
15
 
16
  # Initialize the AuraSR model
17
- aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR")
18
 
19
  # Restore original torch.load
20
  torch.load = original_load
 
5
  from aura_sr import AuraSR
6
  import torch
7
  #import spaces
8
+ import devicetorch
9
+ device = devicetorch.get(torch)
10
 
11
  # Force CPU usage
12
  torch.set_default_tensor_type(torch.FloatTensor)
 
16
  torch.load = lambda *args, **kwargs: original_load(*args, **kwargs, map_location=torch.device('cpu'))
17
 
18
  # Initialize the AuraSR model
19
+ aura_sr = AuraSR.from_pretrained("fal-ai/AuraSR", device=device)
20
 
21
  # Restore original torch.load
22
  torch.load = original_load