Optimize MPS to use bfloat16 precision
After testing, MPS can handle bfloat16 precision without autocast.
This reduces memory usage while maintaining stable output.
Changes from previous commit:
- Use bfloat16 for images on both MPS and CUDA (unified dtype)
- Keep nullcontext() for MPS (autocast on MPS causes issues; see the sketch after this list)
- CUDA path unchanged (still uses bfloat16 autocast)
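
A minimal standalone sketch of the resulting behavior; `device_type` and the placeholder generate call stand in for the model's own attributes and generation code:

    import torch
    from contextlib import nullcontext

    # Both MPS and CUDA now keep image tensors in bfloat16.
    image_dtype = torch.bfloat16

    # Autocast only on CUDA; MPS tensors are already bfloat16, and wrapping
    # generation in torch.autocast on MPS caused issues.
    device_type = "mps" if torch.backends.mps.is_available() else "cuda"
    autocast_ctx = (
        nullcontext()
        if device_type == "mps"
        else torch.autocast("cuda", dtype=torch.bfloat16)
    )

    with autocast_ctx:
        with torch.no_grad():
            pass  # model.generate(...) runs here in the real code
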
Key insight: The row-wise embedding assignment fix was the critical
change. With that in place, bfloat16 works stably on MPS without
needing fp32 precision.
Tested on: macOS 26.0.1, Apple M4 Max, PyTorch 2.9.0
- modeling_deepseekocr.py +6 -6
modeling_deepseekocr.py (CHANGED)

@@ -816,8 +816,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
 
 
-                # MPS
-                image_dtype = torch.
+                # MPS and CUDA both use bfloat16
+                image_dtype = torch.bfloat16
                 images_list.append(image_transform(global_view).to(image_dtype))
 
                 # global_view_tensor = image_transform(global_view).to(torch.bfloat16)

@@ -865,8 +865,8 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
                 # else:
                 global_view = ImageOps.pad(image, (image_size, image_size),
                                            color=tuple(int(x * 255) for x in image_transform.mean))
-                # MPS
-                image_dtype = torch.
+                # MPS and CUDA both use bfloat16
+                image_dtype = torch.bfloat16
                 images_list.append(image_transform(global_view).to(image_dtype))
 
                 if base_size == 1024:

@@ -932,7 +932,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
 
         if not eval_mode:
             streamer = NoEOSTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
-            # MPS: no autocast (pure
+            # MPS: no autocast (pure bfloat16); CUDA: bfloat16 autocast
             autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
             with autocast_ctx:
                 with torch.no_grad():

@@ -952,7 +952,7 @@ class DeepseekOCRForCausalLM(DeepseekV2ForCausalLM):
             )
 
         else:
-            # MPS: no autocast (pure
+            # MPS: no autocast (pure bfloat16); CUDA: bfloat16 autocast
            autocast_ctx = nullcontext() if self.device.type == "mps" else torch.autocast("cuda", dtype=torch.bfloat16)
            with autocast_ctx:
                with torch.no_grad():
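
For reference, a minimal loading snippet to exercise this path on Apple silicon; the repo id, trust_remote_code flag, and AutoModel usage are assumptions for illustration, not part of this commit:

    import torch
    from transformers import AutoModel, AutoTokenizer

    model_id = "deepseek-ai/DeepSeek-OCR"  # assumed repo id
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,  # matches the unified image dtype above
    )
    device = "mps" if torch.backends.mps.is_available() else "cuda"
    model = model.to(device).eval()
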