Oysiyl commited on
Commit
05761ac
1 Parent(s): 4fa08cf

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +5 -4
handler.py CHANGED
@@ -11,10 +11,11 @@ import numpy as np
11
 
12
  # set device
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
- if device.type != 'cuda':
15
- raise ValueError("need to run on GPU")
16
- # set mixed precision dtype
17
- dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
 
18
 
19
  class EndpointHandler():
20
  def __init__(self, path=""):
 
11
 
12
  # set device
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
+ if torch.cuda.is_available():
15
+ # set mixed precision dtype
16
+ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
17
+ else:
18
+ dtype = torch.float32
19
 
20
  class EndpointHandler():
21
  def __init__(self, path=""):