| | """
|
| | MPS optimizations for Vortex model on Apple Silicon.
|
| | Uses PyTorch MPS backend with MPS-compatible ops only.
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from typing import Optional, Dict, Any
|
| |
|
| |
|
def optimize_for_mps(
    model: nn.Module,
    config: Dict,
    use_sdpa: bool = True,
) -> nn.Module:
    """
    Apply MPS optimizations to model.

    Moves the model to the MPS device, casts it to the dtype requested by
    the config, and optionally swaps the custom attention path for
    PyTorch SDPA.

    Args:
        model: VortexModel (any nn.Module with the expected attn layout).
        config: Model config dict; ``config["dtype"]`` selects the cast.
            "bfloat16" and "float16" both map to float16; anything else
            maps to float32.
        use_sdpa: Use PyTorch scaled dot product attention (MPS compatible).

    Returns:
        Optimized model, on the MPS device.
    """
    device = torch.device("mps")
    model = model.to(device)

    dtype_str = config.get("dtype", "bfloat16")
    # NOTE: a "bfloat16" request deliberately falls back to float16 —
    # bfloat16 support on the MPS backend is limited. "float16" previously
    # fell through to float32 by accident; it now maps to float16 too.
    if dtype_str in ("bfloat16", "float16"):
        dtype = torch.float16
    else:
        dtype = torch.float32
    model = model.to(dtype)

    if use_sdpa:
        model = _apply_sdpa(model)
        print("Applied PyTorch SDPA for MPS")

    return model
|
| |
|
| |
|
| | def _apply_sdpa(model: nn.Module) -> nn.Module:
|
| | """
|
| | Replace custom attention with PyTorch SDPA.
|
| | SDPA is optimized for MPS backend.
|
| | """
|
| | for name, module in model.named_modules():
|
| | if hasattr(module, 'attn') and hasattr(module.attn, 'forward_optimized'):
|
| |
|
| | original_forward = module.attn.forward
|
| |
|
| | def sdpa_forward(self, x, *args, **kwargs):
|
| | return self._standard_attention(x, kwargs.get('attention_mask'))
|
| |
|
| | module.attn.forward = sdpa_forward.__get__(module.attn, type(module.attn))
|
| |
|
| | return model
|
| |
|
| |
|
def get_mps_memory_usage() -> Dict[str, float]:
    """Get current MPS memory usage in GB."""
    if not torch.backends.mps.is_available():
        return {"error": "MPS not available"}

    # psutil is imported lazily so the module loads without it on
    # machines where MPS is unavailable.
    import psutil

    # Process-level RSS/VMS are used as the memory proxy here.
    info = psutil.Process().memory_info()
    gib = 1e9
    return {
        "rss_gb": info.rss / gib,
        "vms_gb": info.vms / gib,
    }
|
| |
|
| |
|
def profile_model_mps(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 50,
) -> Dict[str, float]:
    """
    Profile model performance on MPS.

    Args:
        model: Model to profile
        input_ids: Example input
        num_warmup: Number of warmup runs
        num_runs: Number of profiling runs

    Returns:
        Dictionary with timing statistics
    """
    import time

    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    on_mps = device.type == "mps"

    # Untimed warmup passes let kernel compilation/caches settle.
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)
        if on_mps:
            torch.mps.synchronize()

    # Drain any pending MPS work before starting the clock.
    if on_mps:
        torch.mps.synchronize()
    start = time.time()

    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)
        if on_mps:
            torch.mps.synchronize()

    elapsed = time.time() - start
    avg_time = elapsed / num_runs
    # NOTE(review): throughput counts seq_len only, not batch * seq_len —
    # confirm this is the intended per-sequence metric.
    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": input_ids.shape[1] / avg_time,
    }
|
| |
|
| |
|
def test_mps_optimize():
    """Smoke-test the MPS optimization pipeline on a tiny Vortex model."""
    if not torch.backends.mps.is_available():
        print("MPS not available, skipping test")
        return

    # Project imports are deferred so the guard above runs even when the
    # project packages are absent.
    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the 7B config so the smoke test builds and runs quickly.
    config = dict(VORTEX_7B_CONFIG)
    config.update(d_model=512, num_layers=2, num_heads=8, vocab_size=1000)

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    model = optimize_for_mps(model, config, use_sdpa=True)

    # Forward a small random batch (batch=2, seq_len=128) through the
    # optimized model to confirm the pipeline end to end.
    input_ids = torch.randint(0, config["vocab_size"], (2, 128)).to("mps")
    with torch.no_grad():
        logits = model(input_ids)["logits"]

    print(f"Output shape: {logits.shape}")
    print("MPS optimize test passed!")
|
| |
|
| |
|
# Allow running this module directly as a standalone MPS smoke test.
if __name__ == "__main__":
    test_mps_optimize()
|
| |
|