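Update handler.py: select bfloat16 for autocast only when the GPU also reports compute capability major version >= 8 (otherwise fall back to float16)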
Browse files
- handler.py +1 -1
handler.py
CHANGED
|
@@ -39,7 +39,7 @@ class EndpointHandler:
|
|
| 39 |
self.model.eval()
|
| 40 |
self.model.to(self.device)
|
| 41 |
self.model = torch.compile(self.model)
|
| 42 |
-
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16'
|
| 43 |
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
| 44 |
self.ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)
|
| 45 |
print("Warming up hardware 🔥")
|
|
|
|
| 39 |
self.model.eval()
|
| 40 |
self.model.to(self.device)
|
| 41 |
self.model = torch.compile(self.model)
|
| 42 |
+
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() and torch.cuda.get_device_capability()[0] >= 8 else 'float16' # 'float32' or 'bfloat16' or 'float16'
|
| 43 |
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
|
| 44 |
self.ctx = nullcontext() if device == 'cpu' else torch.amp.autocast(device_type=device, dtype=ptdtype)
|
| 45 |
print("Warming up hardware 🔥")
|