import json
from typing import Dict, Optional, Tuple

import requests
from huggingface_hub import HfApi


class ModelMemoryCalculator:
    def __init__(self):
        self.hf_api = HfApi()
        self.cache = {}  # Cache results to avoid repeated API calls

    def get_model_memory_requirements(self, model_id: str) -> Dict:
        """
        Calculate memory requirements for a given HuggingFace model.

        Args:
            model_id: HuggingFace model identifier (e.g., "black-forest-labs/FLUX.1-schnell")

        Returns:
            Dict with memory information including:
            - total_params: Total parameter count
            - memory_fp32_gb: Memory in GB at FP32 precision
            - memory_fp16_gb: Memory in GB at FP16 precision
            - memory_bf16_gb: Memory in GB at BF16 precision
            - safetensors_files: List of safetensor files and their sizes
        """
        if model_id in self.cache:
            return self.cache[model_id]

        try:
            print(f"Fetching model info for {model_id}...")

            # Get model info (also validates that the repo exists)
            model_info = self.hf_api.model_info(model_id)
            print("Model info retrieved successfully")

            # Get safetensors metadata for all .safetensors files in the repo
            print("Fetching safetensors metadata...")
            safetensors_metadata = self.hf_api.get_safetensors_metadata(model_id)
            print(f"Found {len(safetensors_metadata.files_metadata)} safetensor files")

            total_params = 0
            safetensors_files = []

            # Iterate through all safetensor files
            for filename, file_metadata in safetensors_metadata.files_metadata.items():
                file_params = 0
                file_size_bytes = 0

                # Calculate parameters from the per-tensor metadata
                for tensor_name, tensor_info in file_metadata.tensors.items():
                    # Parameter count is the product of the tensor's dimensions
                    tensor_params = 1
                    for dim in tensor_info.shape:
                        tensor_params *= dim
                    file_params += tensor_params

                    # Calculate byte size based on dtype
                    bytes_per_param = self._get_bytes_per_param(tensor_info.dtype)
                    file_size_bytes += tensor_params * bytes_per_param

                total_params += file_params
                safetensors_files.append({
                    'filename': filename,
                    'parameters': file_params,
                    'size_bytes': file_size_bytes,
                    'size_mb': file_size_bytes / (1024 * 1024)
                })
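
            # Note: huggingface_hub's SafetensorsRepoMetadata also exposes a
            # parameter_count mapping (per-dtype totals), which can serve as a
            # cross-check against the total accumulated above.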

            # Calculate memory requirements for different precisions
            memory_requirements = {
                'model_id': model_id,
                'total_params': total_params,
                'total_params_billions': total_params / 1e9,
                'memory_fp32_gb': (total_params * 4) / (1024**3),  # 4 bytes per param
                'memory_fp16_gb': (total_params * 2) / (1024**3),  # 2 bytes per param
                'memory_bf16_gb': (total_params * 2) / (1024**3),  # 2 bytes per param
                'memory_int8_gb': (total_params * 1) / (1024**3),  # 1 byte per param
                'safetensors_files': safetensors_files,
                'estimated_inference_memory_fp16_gb': self._estimate_inference_memory(total_params, 'fp16'),
                'estimated_inference_memory_bf16_gb': self._estimate_inference_memory(total_params, 'bf16'),
            }

            # Cache the result
            self.cache[model_id] = memory_requirements
            return memory_requirements

        except Exception as e:
            return {
                'error': str(e),
                'model_id': model_id,
                'total_params': 0,
                'memory_fp32_gb': 0,
                'memory_fp16_gb': 0,
                'memory_bf16_gb': 0,
            }

    def _get_bytes_per_param(self, dtype: str) -> int:
        """Get bytes per parameter for different data types."""
        dtype_map = {
            'F32': 4, 'float32': 4,
            'F16': 2, 'float16': 2,
            'BF16': 2, 'bfloat16': 2,
            'I8': 1, 'int8': 1,
            'I32': 4, 'int32': 4,
            'I64': 8, 'int64': 8,
        }
        return dtype_map.get(dtype, 4)  # Default to 4 bytes (FP32)
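
    # The uppercase keys above ('F32', 'BF16', 'I8', ...) match the dtype strings
    # stored in safetensors file headers; the lowercase aliases are kept as a
    # convenience fallback for callers passing torch-style names.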

    def _estimate_inference_memory(self, total_params: int, precision: str) -> float:
        """
        Estimate memory requirements during inference.
        This includes model weights + activations + intermediate tensors.
        """
        bytes_per_param = 2 if precision in ['fp16', 'bf16'] else 4

        # Model weights
        model_memory = (total_params * bytes_per_param) / (1024**3)

        # Estimate activation memory (rough approximation):
        # for diffusion models, activations can be 1.5-3x the model size during inference
        activation_multiplier = 2.0
        total_inference_memory = model_memory * (1 + activation_multiplier)

        return total_inference_memory
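
    # Worked example for the heuristic above (hypothetical 12B-parameter model, FP16):
    #   weights:  12e9 params * 2 bytes = 24e9 bytes ~= 22.35 GB
    #   estimate: 22.35 GB * (1 + 2.0) ~= 67.06 GB
    # The 2.0 activation multiplier is a rough rule of thumb, not a measured value;
    # real usage varies with resolution, batch size and scheduler.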

    def get_memory_recommendation(self, model_id: str, available_vram_gb: float) -> Dict:
        """
        Get memory recommendations based on available VRAM.

        Args:
            model_id: HuggingFace model identifier
            available_vram_gb: Available VRAM in GB

        Returns:
            Dict with recommendations for precision, offloading, etc.
        """
        memory_info = self.get_model_memory_requirements(model_id)

        if 'error' in memory_info:
            return {'error': memory_info['error']}

        recommendations = {
            'model_id': model_id,
            'available_vram_gb': available_vram_gb,
            'model_memory_fp16_gb': memory_info['memory_fp16_gb'],
            'estimated_inference_memory_fp16_gb': memory_info['estimated_inference_memory_fp16_gb'],
            'recommendations': []
        }

        inference_memory_fp16 = memory_info['estimated_inference_memory_fp16_gb']
        inference_memory_bf16 = memory_info['estimated_inference_memory_bf16_gb']

        # Determine recommendations
        if available_vram_gb >= inference_memory_bf16:
            recommendations['recommendations'].append("✅ Full model can fit in VRAM with BF16 precision")
            recommendations['recommended_precision'] = 'bfloat16'
            recommendations['cpu_offload'] = False
            recommendations['attention_slicing'] = False
        elif available_vram_gb >= inference_memory_fp16:
            recommendations['recommendations'].append("✅ Full model can fit in VRAM with FP16 precision")
            recommendations['recommended_precision'] = 'float16'
            recommendations['cpu_offload'] = False
            recommendations['attention_slicing'] = False
        elif available_vram_gb >= memory_info['memory_fp16_gb']:
            recommendations['recommendations'].append("⚠️ Model weights fit, but may need memory optimizations")
            recommendations['recommended_precision'] = 'float16'
            recommendations['cpu_offload'] = False
            recommendations['attention_slicing'] = True
            recommendations['vae_slicing'] = True
        else:
            recommendations['recommendations'].append("🔄 Requires CPU offloading and memory optimizations")
            recommendations['recommended_precision'] = 'float16'
            recommendations['cpu_offload'] = True
            recommendations['sequential_offload'] = True
            recommendations['attention_slicing'] = True
            recommendations['vae_slicing'] = True

        return recommendations

    def format_memory_info(self, model_id: str) -> str:
        """Format memory information for display."""
        info = self.get_model_memory_requirements(model_id)

        if 'error' in info:
            return f"❌ Error calculating memory for {model_id}: {info['error']}"

        output = f"""
📊 **Memory Requirements for {model_id}**

🔢 **Parameters**: {info['total_params_billions']:.2f}B parameters

💾 **Model Memory**:
  • FP32: {info['memory_fp32_gb']:.2f} GB
  • FP16/BF16: {info['memory_fp16_gb']:.2f} GB
  • INT8: {info['memory_int8_gb']:.2f} GB

🚀 **Estimated Inference Memory**:
  • FP16: {info['estimated_inference_memory_fp16_gb']:.2f} GB
  • BF16: {info['estimated_inference_memory_bf16_gb']:.2f} GB

📁 **SafeTensor Files**: {len(info['safetensors_files'])} files
"""
        return output.strip()
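

# A minimal sketch of how the flags returned by get_memory_recommendation() could be
# applied to a loaded diffusers pipeline. This helper is an illustration, not part of
# the calculator: the function name is this sketch's own, and it assumes `diffusers`
# is installed and that the pipeline exposes the usual memory helpers
# (enable_sequential_cpu_offload, enable_model_cpu_offload, enable_attention_slicing,
# enable_vae_slicing); hasattr() guards cover pipelines that lack one of them.
def apply_memory_recommendations(pipe, recommendations: Dict):
    """Apply the offload/slicing flags from get_memory_recommendation() to a pipeline."""
    if recommendations.get('sequential_offload') and hasattr(pipe, 'enable_sequential_cpu_offload'):
        pipe.enable_sequential_cpu_offload()  # lowest VRAM, slowest
    elif recommendations.get('cpu_offload') and hasattr(pipe, 'enable_model_cpu_offload'):
        pipe.enable_model_cpu_offload()  # offload whole sub-models on demand
    if recommendations.get('attention_slicing') and hasattr(pipe, 'enable_attention_slicing'):
        pipe.enable_attention_slicing()
    if recommendations.get('vae_slicing') and hasattr(pipe, 'enable_vae_slicing'):
        pipe.enable_vae_slicing()
    return pipe
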

# Example usage and testing
if __name__ == "__main__":
    calculator = ModelMemoryCalculator()

    # Test with FLUX.1-schnell
    model_id = "black-forest-labs/FLUX.1-schnell"
    print(f"Testing memory calculation for {model_id}...")

    memory_info = calculator.get_model_memory_requirements(model_id)
    print(json.dumps(memory_info, indent=2))

    # Test recommendations
    print("\n" + "="*50)
    print("MEMORY RECOMMENDATIONS")
    print("="*50)

    vram_options = [8, 16, 24, 40]
    for vram in vram_options:
        rec = calculator.get_memory_recommendation(model_id, vram)
        print(f"\n🎯 For {vram}GB VRAM:")
        if 'recommendations' in rec:
            for r in rec['recommendations']:
                print(f"  {r}")

    # Format for display
    print("\n" + "="*50)
    print("FORMATTED OUTPUT")
    print("="*50)
    print(calculator.format_memory_info(model_id))
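
    # Example (left commented out because it downloads the full model weights):
    # load the pipeline at the recommended precision and apply the flags above.
    # FluxPipeline.from_pretrained and torch_dtype are standard diffusers/torch APIs;
    # the rest just reads the recommendation dict produced by this script.
    #
    #   import torch
    #   from diffusers import FluxPipeline
    #
    #   rec = calculator.get_memory_recommendation(model_id, available_vram_gb=16)
    #   dtype = torch.bfloat16 if rec.get('recommended_precision') == 'bfloat16' else torch.float16
    #   pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype)
    #   pipe = apply_memory_recommendations(pipe, rec)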