fix(modeling): device-agnostic MPS/CPU support — replace .cuda() with .to(device), fix masked_scatter_ broadcast, and use device-type-aware autocast

#5
by kushdab - opened

Problem

The model is hard-coded to CUDA in three ways that cause crashes on Apple Silicon (MPS) and CPU:

1. masked_scatter_ broadcast + .cuda() (line 582)

# Before
inputs_embeds[idx].masked_scatter_(
    images_seq_mask[idx].unsqueeze(-1).cuda(), images_in_this_batch
)

Two bugs in one line:

  • .unsqueeze(-1) leaves the mask with shape [seq_len, 1], while inputs_embeds[idx] has shape [seq_len, hidden_dim]. masked_scatter_ only requires the mask to be broadcastable to that shape, but on MPS the implicit broadcast raises RuntimeError: masked_scatter_ received a mask with wrong number of elements. Calling expand_as() makes the broadcast explicit and safe.
  • .cuda() hardcodes the device and crashes immediately on MPS/CPU with RuntimeError: Expected all tensors to be on the same device.
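
The explicit-broadcast pattern can be tried in isolation on CPU. The following is a minimal sketch with made-up shapes and values (seq_len=6, hidden_dim=4, three image positions); the variable names mirror the ones above, but the tensors are illustrative and not taken from the model.

```python
import torch

seq_len, hidden_dim = 6, 4  # illustrative sizes, not the model's real dimensions

# Stand-in for inputs_embeds[idx]: one sequence of token embeddings.
inputs_embeds = torch.zeros(seq_len, hidden_dim)

# Stand-in for images_seq_mask[idx]: True where an image token sits.
images_seq_mask = torch.tensor([False, True, True, False, True, False])

# Stand-in for images_in_this_batch: one feature row per True position.
image_feats = torch.arange(3 * hidden_dim, dtype=torch.float32).reshape(3, hidden_dim) + 1.0

# Expand the [seq_len, 1] mask to [seq_len, hidden_dim] explicitly and keep it
# on the same device as the embeddings instead of calling .cuda().
mask = images_seq_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
inputs_embeds.masked_scatter_(mask, image_feats)
```

After this runs, the rows of inputs_embeds selected by images_seq_mask hold the rows of image_feats in order, and the remaining rows are untouched.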

2. All remaining .cuda() calls in infer() and infer_multi() (17 occurrences)

Every tensor sent to generate() is moved with .cuda(). On MPS these calls fail with Expected all tensors to be on the same device, but found at least two devices, mps:0 and cuda:0.

3. torch.autocast("cuda", ...) (3 occurrences)

torch.autocast takes the device type as its first argument. Passing "cuda" on an MPS machine raises RuntimeError: "cuda" is not a supported autocast device type on this platform.

Fix

# At the start of infer() and infer_multi():
device = next(self.parameters()).device  # MPS / CUDA / CPU

# masked_scatter_ fix — expand mask + move to correct device:
_mask = images_seq_mask[idx].unsqueeze(-1).expand_as(inputs_embeds[idx]).to(inputs_embeds.device)
inputs_embeds[idx].masked_scatter_(_mask, images_in_this_batch)

# All other .cuda() → .to(device)
input_ids.unsqueeze(0).to(device)  # etc.

# autocast → device-type-aware:
torch.autocast(device.type, dtype=torch.bfloat16)
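
Put together, the pattern looks roughly like the sketch below. TinyModel and its proj layer are hypothetical stand-ins for the real model, used only to show how the device is read from the parameters and reused for both tensor placement and autocast. Whether bfloat16 autocast is available on a given backend depends on the PyTorch version, so treat the dtype here as an assumption.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for the real model; only the device handling matters."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 4)

    def infer(self, features):
        # Follow whatever device the weights live on (MPS / CUDA / CPU).
        device = next(self.parameters()).device
        # .to(device) replaces the hard-coded .cuda().
        x = features.unsqueeze(0).to(device)
        # autocast receives the device *type* string, e.g. "cpu", "cuda", "mps".
        with torch.autocast(device.type, dtype=torch.bfloat16):
            return self.proj(x)


model = TinyModel()  # parameters are created on CPU by default
out = model.infer(torch.randn(8))
```

Moving the model with model.to("mps") or model.to("cuda") before calling infer() would then route inputs and autocast to that device without further code changes.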

Testing

Validated on CUDA (no regression). The expand_as path preserves the previous behaviour: on CUDA the expanded mask is identical to what the implicit broadcast would have produced.
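
The equivalence claim can be spot-checked with a small script such as the one below. It is a sketch that runs on CPU with made-up shapes, and it assumes the backend accepts a broadcastable mask in masked_scatter_ (which is the case on the backends where the old code worked).

```python
import torch

torch.manual_seed(0)
seq_len, hidden_dim = 5, 3  # illustrative sizes
mask_1d = torch.tensor([True, False, True, True, False])
source = torch.randn(int(mask_1d.sum()), hidden_dim)

# Old path: rely on implicit broadcasting of the [seq_len, 1] mask.
implicit = torch.zeros(seq_len, hidden_dim)
implicit.masked_scatter_(mask_1d.unsqueeze(-1), source)

# New path: expand the mask explicitly to the target shape.
explicit = torch.zeros(seq_len, hidden_dim)
explicit.masked_scatter_(mask_1d.unsqueeze(-1).expand_as(explicit), source)

same = torch.equal(implicit, explicit)
```

If same is True, both paths wrote the same values into the same positions for this input.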

Fixes: https://github.com/baidu/Unlimited-OCR/issues/18
Related: https://github.com/baidu/Unlimited-OCR/pull/29
