"""Helpers for detecting PyTorch's MPS backend and emptying its cache."""

import logging

import torch
import torch.backends
import torch.backends.mps
from packaging import version

logger = logging.getLogger(__name__)


def check_for_mps() -> bool:
    """Return True when the MPS backend can be used by this PyTorch install.

    For PyTorch versions up to and including 2.0.1 the check looks at
    ``torch.has_mps`` and then tries to move a one-element tensor to the
    ``mps`` device. For newer versions it asks ``torch.backends.mps``
    whether the backend is both available and built.

    Returns:
        True if MPS appears usable, False otherwise. Failures during the
        probe never propagate; they result in False.
    """
    if version.parse(torch.__version__) <= version.parse("2.0.1"):
        if not getattr(torch, "has_mps", False):
            return False
        try:
            # Actually touching the device is the probe here; any failure
            # simply means MPS is not usable.
            torch.zeros(1).to(torch.device("mps"))
            return True
        except Exception:
            return False

    try:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed, and report the
        # failure as what it is (an availability check, not a GC failure).
        logger.warning("MPS availability check failed", exc_info=True)
        return False


# Evaluated once at import time so callers can read a plain boolean.
has_mps = check_for_mps()


def torch_mps_gc() -> None:
    """Call ``torch.mps.empty_cache()`` on a best-effort basis.

    Any exception, including a failure to import ``torch.mps``, is logged
    as a warning and suppressed so that callers are never interrupted.
    """
    try:
        from torch.mps import empty_cache

        empty_cache()
    except Exception:
        logger.warning("MPS garbage collection failed", exc_info=True)


if __name__ == "__main__":
    print(torch.__version__)
    print(has_mps)
    torch_mps_gc()