Yixuan Li commited on
Commit
9857e4c
·
1 Parent(s): 8221cd6

forcing torch.load to CPU

Browse files
Files changed (1) hide show
  1. app.py +3 -1
app.py CHANGED
@@ -9,7 +9,9 @@ device = 'cuda' if torch.cuda.is_available() else 'cpu'
9
  _old_load = torch.load
10
 
11
  def safe_torch_load(*args, **kwargs):
12
- if 'map_location' not in kwargs:
 
 
13
  kwargs['map_location'] = device
14
  return _old_load(*args, **kwargs)
15
 
 
9
  _old_load = torch.load
10
 
11
  def safe_torch_load(*args, **kwargs):
12
+ if len(args) >= 2:
13
+ args[1] = device
14
+ else:
15
  kwargs['map_location'] = device
16
  return _old_load(*args, **kwargs)
17