import torch
import torch.nn as nn


class MPSAdaptiveAvgPool2d(nn.Module):
    """
    A wrapper around nn.AdaptiveAvgPool2d that falls back to the CPU when
    running on MPS and the input/output size combination is not supported.

    The MPS backend only implements adaptive average pooling when each input
    spatial size is divisible by the corresponding output size, so any other
    combination is routed through the CPU and moved back afterwards.
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size
        self.pool = nn.AdaptiveAvgPool2d(output_size)
    def forward(self, x):
        if x.device.type == 'mps':
            # Check whether this input/output combination is supported on MPS.
            h, w = x.shape[2], x.shape[3]
            if isinstance(self.output_size, (tuple, list)):
                out_h, out_w = self.output_size
            else:
                out_h = out_w = self.output_size
            # AdaptiveAvgPool2d accepts None to mean "same size as the input",
            # which is trivially divisible, so substitute the input size.
            out_h = h if out_h is None else out_h
            out_w = w if out_w is None else out_w
            # MPS requires input sizes to be divisible by output sizes.
            if h % out_h != 0 or w % out_w != 0:
                # Fall back to the CPU for this operation only, then move the
                # result back to the original device.
                device = x.device
                output_cpu = self.pool(x.cpu())
                return output_cpu.to(device)
        # Use normal pooling on CPU/CUDA, or on MPS when supported.
        return self.pool(x)
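

# A minimal usage sketch with hypothetical sizes: the wrapper is a drop-in
# replacement for nn.AdaptiveAvgPool2d, and a 10x10 input pooled to 7x7 is
# exactly the kind of non-divisible combination that triggers the fallback.
if __name__ == "__main__":
    device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
    pool = MPSAdaptiveAvgPool2d((7, 7))
    x = torch.randn(1, 64, 10, 10, device=device)  # 10 % 7 != 0 -> CPU fallback on MPS
    y = pool(x)
    print(y.shape, y.device)  # torch.Size([1, 64, 7, 7]), back on the original device

# Note: as a blunter alternative, PyTorch can fall back globally for every
# unsupported MPS op via the PYTORCH_ENABLE_MPS_FALLBACK=1 environment
# variable (set before the process starts); the wrapper above confines the
# CPU round-trip to this one module.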