| """
|
| CUDA optimizations for Vortex model on Nvidia 4060 laptop.
|
| Flash Attention 2, torch.compile, INT8 quantization.
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| from typing import Optional, Dict, Any
|
|
|
|
|
def optimize_for_cuda(
    model: nn.Module,
    config: Dict[str, Any],
    use_flash_attention: bool = True,
    use_torch_compile: bool = True,
    compile_mode: str = "reduce-overhead",
    quantization: Optional[str] = None,
) -> nn.Module:
    """
    Apply CUDA optimizations to the model.

    Args:
        model: VortexModel instance
        config: Model config
        use_flash_attention: Enable Flash Attention 2
        use_torch_compile: Wrap the model with torch.compile
        compile_mode: Compile mode ("reduce-overhead" or "max-autotune")
        quantization: None, "int8", or "int4"

    Returns:
        Optimized model
    """
    device = torch.device("cuda")
    model = model.to(device)

    # Cast parameters to the dtype requested in the config (default bfloat16).
    dtype_str = config.get("dtype", "bfloat16")
    if dtype_str == "bfloat16":
        dtype = torch.bfloat16
    elif dtype_str == "float16":
        dtype = torch.float16
    else:
        dtype = torch.float32
    model = model.to(dtype)

    if use_flash_attention:
        model = _apply_flash_attention(model)
        print("Applied Flash Attention 2")

    # Quantize before compiling so the compiled graph sees the final modules;
    # swapping Linear layers after torch.compile would force recompilation.
    if quantization == "int8":
        model = _apply_int8_quantization(model)
        print("Applied INT8 quantization")
    elif quantization == "int4":
        model = _apply_int4_quantization(model)
        print("Applied INT4 quantization")

    if use_torch_compile:
        model = torch.compile(
            model,
            mode=compile_mode,
            fullgraph=True,
            dynamic=True,
        )
        print(f"Applied torch.compile with mode={compile_mode}")

    return model
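
# Example usage (sketch). VortexModel and VORTEX_7B_CONFIG come from this repo
# (see test_cuda_optimize below); the particular settings are only a suggestion
# for fitting a 7B model in the 4060's 8 GB of VRAM, not a tested recipe.
#
#     from models.vortex_model import VortexModel
#     from configs.vortex_7b_config import VORTEX_7B_CONFIG
#
#     model = VortexModel(VORTEX_7B_CONFIG)
#     model = optimize_for_cuda(
#         model,
#         VORTEX_7B_CONFIG,
#         use_flash_attention=True,
#         use_torch_compile=True,
#         compile_mode="max-autotune",
#         quantization="int4",
#     )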


def _apply_flash_attention(model: nn.Module) -> nn.Module:
    """
    Replace standard attention with Flash Attention 2.

    Requires: pip install flash-attn
    """
    try:
        # Imported only to verify the package is installed; attention modules
        # are expected to call flash_attn_func inside _flash_attention_forward.
        from flash_attn import flash_attn_func  # noqa: F401

        for name, module in model.named_modules():
            # Any module exposing `use_flash_attention` is assumed to also
            # implement `_flash_attention_forward`.
            if hasattr(module, "use_flash_attention"):
                module.use_flash_attention = True

                def flash_forward(self, x, *args, **kwargs):
                    return self._flash_attention_forward(x, *args, **kwargs)

                # Bind the replacement forward to this module instance.
                module.forward = flash_forward.__get__(module, type(module))

        return model

    except ImportError:
        print("Flash Attention not available. Install with: pip install flash-attn")
        return model
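
# A minimal sketch of the hook the patch above assumes, with hypothetical
# projection names (to_qkv, out_proj). flash_attn_func takes q, k, v shaped
# (batch, seqlen, num_heads, head_dim) in fp16/bf16 on CUDA:
#
#     def _flash_attention_forward(self, x):
#         q, k, v = self.to_qkv(x)                      # each (B, S, H, D)
#         out = flash_attn_func(q, k, v, causal=True)   # (B, S, H, D)
#         return self.out_proj(out.flatten(-2))         # merge heads -> (B, S, H*D)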


def _apply_int8_quantization(model: nn.Module) -> nn.Module:
    """
    Apply INT8 quantization to all Linear layers using bitsandbytes.
    """
    try:
        import bitsandbytes as bnb

        # Snapshot the module list first: we replace children while iterating.
        for name, module in list(model.named_modules()):
            if isinstance(module, nn.Linear):
                parent_name = name.rsplit(".", 1)[0] if "." in name else ""
                child_name = name.rsplit(".", 1)[1] if "." in name else name

                # Walk down to the parent module that owns this Linear.
                parent = model
                if parent_name:
                    for part in parent_name.split("."):
                        parent = getattr(parent, part)

                replacement = bnb.nn.Linear8bitLt(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                    has_fp16_weights=False,
                )
                replacement.weight.data = module.weight.data
                if module.bias is not None:
                    replacement.bias.data = module.bias.data

                # bitsandbytes quantizes the weights when the layer is moved
                # to CUDA, so place the replacement on the original device.
                replacement = replacement.to(module.weight.device)
                setattr(parent, child_name, replacement)

        return model

    except ImportError:
        print("bitsandbytes not available. Install with: pip install bitsandbytes")
        return model
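
# Quick sanity check (sketch): compare CUDA memory before and after quantizing.
# Uses get_cuda_memory_usage() defined later in this module; the exact numbers
# depend on the model and on whatever else is resident on the GPU.
#
#     before = get_cuda_memory_usage()["allocated_gb"]
#     model = _apply_int8_quantization(model)
#     torch.cuda.synchronize()
#     after = get_cuda_memory_usage()["allocated_gb"]
#     print(f"weights: {before:.2f} GB -> {after:.2f} GB")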


def _apply_int4_quantization(model: nn.Module) -> nn.Module:
    """
    Apply INT4 quantization to all Linear layers using bitsandbytes.

    More aggressive than INT8; intended for fitting ~13B parameters in 8 GB of VRAM.
    """
    try:
        import bitsandbytes as bnb

        # Snapshot the module list first: we replace children while iterating.
        for name, module in list(model.named_modules()):
            if isinstance(module, nn.Linear):
                parent_name = name.rsplit(".", 1)[0] if "." in name else ""
                child_name = name.rsplit(".", 1)[1] if "." in name else name

                parent = model
                if parent_name:
                    for part in parent_name.split("."):
                        parent = getattr(parent, part)

                replacement = bnb.nn.Linear4bit(
                    module.in_features,
                    module.out_features,
                    bias=module.bias is not None,
                    compute_dtype=torch.float16,
                    compress_statistics=True,
                )
                replacement.weight.data = module.weight.data
                if module.bias is not None:
                    replacement.bias.data = module.bias.data

                # The 4-bit packing happens when the layer is moved to CUDA.
                replacement = replacement.to(module.weight.device)
                setattr(parent, child_name, replacement)

        return model

    except ImportError:
        print("bitsandbytes not available. Install with: pip install bitsandbytes")
        return model
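
# Rough weight-memory arithmetic behind the "13B in 8 GB" claim (approximate;
# ignores activations, the KV cache, and quantization statistics):
#   bf16:  13e9 params * 2.0 bytes ~= 26.0 GB
#   int8:  13e9 params * 1.0 byte  ~= 13.0 GB
#   int4:  13e9 params * 0.5 bytes ~=  6.5 GB  -> fits in 8 GB VRAM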


def get_cuda_memory_usage() -> Dict[str, Any]:
    """Return current CUDA memory usage in GB."""
    if not torch.cuda.is_available():
        return {"error": "CUDA not available"}

    allocated = torch.cuda.memory_allocated() / 1e9
    reserved = torch.cuda.memory_reserved() / 1e9
    max_allocated = torch.cuda.max_memory_allocated() / 1e9

    return {
        "allocated_gb": allocated,
        "reserved_gb": reserved,
        "max_allocated_gb": max_allocated,
    }


def profile_model(
    model: nn.Module,
    input_ids: torch.Tensor,
    num_warmup: int = 10,
    num_runs: int = 100,
) -> Dict[str, float]:
    """
    Profile model forward-pass performance.

    Args:
        model: Model to profile
        input_ids: Example input
        num_warmup: Number of warmup runs
        num_runs: Number of timed runs

    Returns:
        Dictionary with timing statistics
    """
    model.eval()
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    # Warmup runs (also triggers torch.compile / autotuning, if enabled).
    with torch.no_grad():
        for _ in range(num_warmup):
            _ = model(input_ids)

    torch.cuda.synchronize()
    start = time.perf_counter()

    with torch.no_grad():
        for _ in range(num_runs):
            _ = model(input_ids)

    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    avg_time = elapsed / num_runs
    # Count every token in the batch, not just one sequence.
    tokens_per_sec = input_ids.numel() / avg_time

    return {
        "avg_time_sec": avg_time,
        "tokens_per_sec": tokens_per_sec,
    }
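
# Example (sketch): profile a prompt-sized batch on the optimized model. The
# batch and sequence sizes here are arbitrary; vocab_size is read from the same
# config dict used throughout this module.
#
#     input_ids = torch.randint(0, config["vocab_size"], (1, 512), device="cuda")
#     stats = profile_model(model, input_ids)
#     print(f"{stats['tokens_per_sec']:.0f} tokens/s, {stats['avg_time_sec'] * 1000:.1f} ms/step")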


def test_cuda_optimize():
    """Smoke test for the CUDA optimization helpers."""
    if not torch.cuda.is_available():
        print("CUDA not available, skipping test")
        return

    from models.vortex_model import VortexModel
    from configs.vortex_7b_config import VORTEX_7B_CONFIG

    # Shrink the config so the test builds a tiny model and runs quickly.
    config = VORTEX_7B_CONFIG.copy()
    config["d_model"] = 512
    config["num_layers"] = 2
    config["num_heads"] = 8
    config["vocab_size"] = 1000

    model = VortexModel(config)
    print(f"Model parameters: {model.get_num_params():,}")

    model = optimize_for_cuda(
        model,
        config,
        use_flash_attention=False,
        use_torch_compile=False,
        quantization=None,
    )

    batch_size = 2
    seq_len = 128
    input_ids = torch.randint(0, config["vocab_size"], (batch_size, seq_len)).cuda()

    with torch.no_grad():
        output = model(input_ids)
        logits = output["logits"]

    print(f"Output shape: {logits.shape}")
    print("CUDA optimize test passed!")


if __name__ == "__main__":
    test_cuda_optimize()