Theob committed on
Commit
fac47b5
·
verified ·
1 Parent(s): 4702f0f

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +1 -1
handler.py CHANGED
@@ -39,7 +39,7 @@ class EndpointHandler:
39
  self.model.eval()
40
  self.model.to(self.device)
41
  self.model = torch.compile(self.model)
42
- dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
43
  ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
44
  self.ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)
45
  print("Warming up hardware 🔥")
 
39
  self.model.eval()
40
  self.model.to(self.device)
41
  self.model = torch.compile(self.model)
42
+ dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and torch.cuda.get_device_capability()[0] >= 8 else 'float16' # 'float32' or 'bfloat16' or 'float16'
43
  ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
44
  self.ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)
45
  print("Warming up hardware 🔥")