# vtontry / model/mps_utils.py
# Uploaded by samirk08 via huggingface_hub ("Upload folder using huggingface_hub"), commit 8285881 (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
class MPSAdaptiveAvgPool2d(nn.Module):
    """
    A wrapper around AdaptiveAvgPool2d that falls back to CPU when running on MPS
    and the input/output size combination is not supported.

    On non-MPS devices (CPU, CUDA, ...) this behaves exactly like
    ``nn.AdaptiveAvgPool2d``.

    Args:
        output_size: Target spatial size, in the same forms accepted by
            ``nn.AdaptiveAvgPool2d``: a single int, or a 2-element tuple/list
            whose entries are ints or ``None`` (``None`` keeps that input
            dimension unchanged).
    """

    def __init__(self, output_size):
        super().__init__()
        self.output_size = output_size
        self.pool = nn.AdaptiveAvgPool2d(output_size)

    def _resolve_output_hw(self, in_h, in_w):
        """Return the concrete ``(out_h, out_w)`` for an input of spatial size ``(in_h, in_w)``."""
        if isinstance(self.output_size, (tuple, list)):
            out_h, out_w = self.output_size
        else:
            out_h = out_w = self.output_size
        # ``None`` means "same as the input" for nn.AdaptiveAvgPool2d.
        if out_h is None:
            out_h = in_h
        if out_w is None:
            out_w = in_w
        return out_h, out_w

    def _mps_supports(self, in_h, in_w):
        """Return True if MPS can pool ``(in_h, in_w)`` to the configured output size."""
        out_h, out_w = self._resolve_output_hw(in_h, in_w)
        if out_h <= 0 or out_w <= 0:
            # Avoid a ZeroDivisionError below; the CPU op reports invalid sizes itself.
            return False
        # MPS requires input sizes to be divisible by output sizes.
        return in_h % out_h == 0 and in_w % out_w == 0

    def forward(self, x):
        """
        Apply adaptive average pooling to ``x``.

        Args:
            x: Input tensor of shape ``(N, C, H, W)`` or ``(C, H, W)``.

        Returns:
            The pooled tensor, on the same device as ``x``.
        """
        if x.device.type == 'mps':
            # Spatial dims are the last two for both batched and unbatched inputs.
            in_h, in_w = x.shape[-2], x.shape[-1]
            if not self._mps_supports(in_h, in_w):
                # Fallback to CPU for this operation, then move the result back.
                device = x.device
                return self.pool(x.cpu()).to(device)
        # Use normal pooling for CUDA/CPU or when MPS supports the sizes.
        return self.pool(x)