Spaces:
Sleeping
Sleeping
| """ | |
| Model loading and caching for FlashAttention Explorer. | |
| Uses real HuggingFace models with SDPA attention implementation. | |
| """ | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from typing import Tuple, Optional | |
| import os | |
| from .constants import MODEL_CONFIGS | |
# Module-level caches so repeated load_model / load_tokenizer calls reuse
# already-loaded objects instead of re-downloading / re-instantiating them.
# Keyed by the MODEL_CONFIGS model name (e.g. "SmolLM2-360M").
_model_cache: dict = {}
_tokenizer_cache: dict = {}
def get_device() -> str:
    """Return the preferred compute device: "cuda" when a GPU is visible, else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def load_model(model_name: str, force_reload: bool = False) -> AutoModelForCausalLM:
    """
    Load a model with caching to avoid redundant downloads.

    Args:
        model_name: Key from MODEL_CONFIGS (e.g., "SmolLM2-360M")
        force_reload: If True, reload even if cached

    Returns:
        Loaded model in eval mode, on GPU (via device_map="auto") when CUDA
        is available, otherwise on CPU.

    Raises:
        ValueError: If model_name is not a key in MODEL_CONFIGS.
    """
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}. Available: {list(MODEL_CONFIGS.keys())}")

    if model_name in _model_cache and not force_reload:
        return _model_cache[model_name]

    config = MODEL_CONFIGS[model_name]
    model_id = config["model_id"]

    # Gated models (e.g. Llama) need an auth token; None is fine for public ones.
    token = os.environ.get("HF_TOKEN", None)

    use_cuda = torch.cuda.is_available()
    # FIX: fp16 is poorly supported on CPU (missing/slow kernels); only use it
    # on CUDA and fall back to fp32 for CPU execution.
    dtype = torch.float16 if use_cuda else torch.float32

    # attn_implementation="sdpa" routes attention through torch.nn.functional
    # .scaled_dot_product_attention so its backends (flash / mem-efficient /
    # math) can be switched at runtime.
    # NOTE(review): trust_remote_code=True executes code shipped with the
    # checkpoint — acceptable only because model_ids are pinned in MODEL_CONFIGS.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto" if use_cuda else None,
        attn_implementation="sdpa",  # Enable PyTorch SDPA backends
        token=token,
        trust_remote_code=True,
    )
    # Without device_map the model is already materialized on CPU, so no
    # explicit .to("cpu") is needed.
    model.eval()  # inference only: disable dropout etc.

    _model_cache[model_name] = model
    return model
def load_tokenizer(model_name: str, force_reload: bool = False) -> AutoTokenizer:
    """
    Load tokenizer with caching.

    Args:
        model_name: Key from MODEL_CONFIGS
        force_reload: If True, reload even if cached (mirrors load_model)

    Returns:
        Loaded tokenizer, guaranteed to have a pad token.

    Raises:
        ValueError: If model_name is not a key in MODEL_CONFIGS.
    """
    # Consistency fix: validate like load_model (ValueError with the available
    # keys) instead of letting the config lookup raise a bare KeyError.
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Unknown model: {model_name}. Available: {list(MODEL_CONFIGS.keys())}")

    if model_name in _tokenizer_cache and not force_reload:
        return _tokenizer_cache[model_name]

    config = MODEL_CONFIGS[model_name]
    model_id = config["model_id"]
    # Gated models (e.g. Llama) need an auth token; None is fine for public ones.
    token = os.environ.get("HF_TOKEN", None)

    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        token=token,
        trust_remote_code=True,
    )
    # Some causal-LM tokenizers ship without a pad token; reuse EOS so that
    # batched/padded encoding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    _tokenizer_cache[model_name] = tokenizer
    return tokenizer
def load_model_and_tokenizer(
    model_name: str
) -> Tuple[AutoModelForCausalLM, AutoTokenizer]:
    """
    Convenience wrapper: fetch a model together with its tokenizer.

    Args:
        model_name: Key from MODEL_CONFIGS

    Returns:
        Tuple of (model, tokenizer), each served from cache when available
    """
    return load_model(model_name), load_tokenizer(model_name)
def get_model_memory_footprint(model_name: str) -> dict:
    """
    Calculate theoretical memory footprint for a model.

    Args:
        model_name: Key from MODEL_CONFIGS

    Returns:
        Dict with memory breakdown in GB
    """
    cfg = MODEL_CONFIGS[model_name]
    hidden = cfg["hidden_dim"]
    layers = cfg["layers"]
    vocab = cfg["vocab_size"]

    # Parameter-count approximation:
    #   embeddings / LM head: vocab * hidden each
    #   attention per layer:  4 * hidden^2 (Q, K, V, O projections)
    #   FFN per layer:        ~8 * hidden^2 (typical 4x expansion)
    param_counts = {
        "embeddings": vocab * hidden,
        "attention": 4 * hidden * hidden * layers,
        "ffn": 8 * hidden * hidden * layers,
        "lm_head": vocab * hidden,
    }
    total_params = sum(param_counts.values())

    def to_gb(n_params: int) -> float:
        # FP16 = 2 bytes per parameter.
        return (n_params * 2) / (1024 ** 3)

    return {
        "total_params_millions": total_params / 1e6,
        "model_memory_gb": to_gb(total_params),
        "breakdown": {f"{name}_gb": to_gb(count) for name, count in param_counts.items()},
    }
| def calculate_kv_cache_size( | |
| model_name: str, | |
| seq_len: int, | |
| batch_size: int = 1, | |
| dtype_bytes: int = 2 # FP16 | |
| ) -> dict: | |
| """ | |
| Calculate KV cache memory for given sequence length. | |
| Args: | |
| model_name: Key from MODEL_CONFIGS | |
| seq_len: Sequence length | |
| batch_size: Batch size | |
| dtype_bytes: Bytes per element (2 for FP16, 4 for FP32) | |
| Returns: | |
| Dict with KV cache size information | |
| """ | |
| config = MODEL_CONFIGS[model_name] | |
| layers = config["layers"] | |
| kv_heads = config["kv_heads"] | |
| head_dim = config["head_dim"] | |
| # KV cache size: 2 (K and V) * layers * kv_heads * seq_len * head_dim * batch_size * dtype_bytes | |
| kv_cache_bytes = 2 * layers * kv_heads * seq_len * head_dim * batch_size * dtype_bytes | |
| kv_cache_gb = kv_cache_bytes / (1024 ** 3) | |
| # Calculate what it would be with MHA (all heads have own KV) | |
| q_heads = config["q_heads"] | |
| mha_cache_bytes = 2 * layers * q_heads * seq_len * head_dim * batch_size * dtype_bytes | |
| mha_cache_gb = mha_cache_bytes / (1024 ** 3) | |
| return { | |
| "gqa_cache_gb": kv_cache_gb, | |
| "mha_cache_gb": mha_cache_gb, | |
| "savings_ratio": q_heads / kv_heads, | |
| "savings_gb": mha_cache_gb - kv_cache_gb, | |
| } | |
def clear_model_cache():
    """Drop every cached model and tokenizer to free memory."""
    global _model_cache, _tokenizer_cache
    # .clear() empties the shared dicts in place so any other references to
    # them stay valid.
    for cache in (_model_cache, _tokenizer_cache):
        cache.clear()
    if torch.cuda.is_available():
        # Hand freed CUDA blocks back to the allocator pool.
        torch.cuda.empty_cache()
def get_available_models() -> list:
    """Names of every model this module knows how to load."""
    return [name for name in MODEL_CONFIGS]