| | import logging |
| | import time |
| | from contextlib import contextmanager |
| |
|
| | import torch |
| |
|
| |
|
def vram_usage():
    """Summarize CUDA memory usage for every visible GPU.

    Returns:
        One ``"GPU i: Memory Allocated: X GB, Memory Reserved: Y GB"`` entry
        per device reported by ``torch.cuda.device_count()``, joined with
        ``"; "``. An empty string when no CUDA device is visible.
    """
    bytes_per_gib = 1024**3
    entries = []
    for i in range(torch.cuda.device_count()):
        gpu_memory_allocated = torch.cuda.memory_allocated(i) / bytes_per_gib
        gpu_memory_reserved = torch.cuda.memory_reserved(i) / bytes_per_gib
        entries.append(
            f"GPU {i}: Memory Allocated: {gpu_memory_allocated:.2f} GB, Memory Reserved: {gpu_memory_reserved:.2f} GB"
        )
    # Join with an explicit separator: plain concatenation glued consecutive
    # devices together ("... GBGPU 1: ...") on multi-GPU hosts.
    return "; ".join(entries)
| |
|
| |
|
def ram_usage():
    """Format the resident set size of a ``psutil.Process()`` as a GB string.

    Returns:
        A string of the form ``"RAM Usage: X.XX GB"``, where the value is the
        process RSS in bytes divided by 1024**3.
    """
    # Imported lazily so psutil is only required when this helper is called.
    import psutil

    rss_bytes = psutil.Process().memory_info().rss
    used_gb = rss_bytes / (1024**3)
    return f"RAM Usage: {used_gb:.2f} GB"
| |
|
| |
|
def _vram_percent(part_bytes, total_bytes):
    """Return ``part_bytes`` as a percentage of ``total_bytes`` (0.0 if total is 0)."""
    return 100 * part_bytes / total_bytes if total_bytes else 0.0


def _format_elapsed_hms(seconds):
    """Format a non-negative duration as HH:MM:SS without wrapping at 24 hours."""
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}"


@contextmanager
def resource_logger_context(
    logger: logging.Logger, task_description: str, *, device: int = 0
):
    """
    Context manager to log the resource usage of the current task.

    On exit (whether the body succeeded or raised) a single INFO record is
    emitted with the change in VRAM held by the caching allocator, the share
    of device VRAM currently held, the peak tensor allocation inside the
    block, and the elapsed wall time. When CUDA is unavailable all VRAM
    figures are reported as 0.00% and only the timing is meaningful.

    Args:
        logger: The logger to use to log the resource usage.
        task_description: The description of the task to log.
        device: Index of the CUDA device to inspect. Defaults to 0.
    Yields:
        None
    """
    # Setup happens before the try block so that a failure here surfaces as
    # itself, instead of being masked by errors raised from the finally clause.
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        total_mem_bytes = torch.cuda.get_device_properties(device).total_memory
        # Reserved memory already includes allocated memory, so adding the two
        # would double-count tensor memory; reserved alone is what is held.
        initial_total_bytes = torch.cuda.memory_reserved(device)
        torch.cuda.reset_peak_memory_stats(device)
    else:
        total_mem_bytes = 0
        initial_total_bytes = 0

    # perf_counter is monotonic, unlike time.time(), so clock adjustments
    # cannot produce a negative or skewed duration.
    initial_time = time.perf_counter()
    try:
        yield None
    finally:
        if cuda_ready:
            # Wait for queued kernels first so the elapsed time and memory
            # statistics cover the work actually launched inside the block.
            torch.cuda.synchronize(device)
            final_total_bytes = torch.cuda.memory_reserved(device)
            peak_allocated_bytes = torch.cuda.max_memory_allocated(device)
        else:
            final_total_bytes = 0
            peak_allocated_bytes = 0
        elapsed_seconds = time.perf_counter() - initial_time

        delta_vram_percent_total = _vram_percent(
            final_total_bytes - initial_total_bytes, total_mem_bytes
        )
        current_percent_vram_taken = _vram_percent(
            final_total_bytes, total_mem_bytes
        )
        block_peak_percent = _vram_percent(peak_allocated_bytes, total_mem_bytes)
        delta_time_str = _format_elapsed_hms(elapsed_seconds)

        logger.info(
            f"For task: {task_description}, ΔVRAM % (total): {delta_vram_percent_total:.2f}%, Current % of VRAM taken: {current_percent_vram_taken:.2f}%, Block Peak % of device VRAM: {block_peak_percent:.2f}%, ΔTime: {delta_time_str}"
        )
| |
|