| |
| |
|
|
| import contextlib |
| import logging |
| import os |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
| import torch.distributed |
| import wandb |
| import xformers.profiler |
| from torch.profiler.profiler import profile |
| from xformers.profiler import MemSnapshotsProfiler, PyTorchProfiler |
|
|
| from core.distributed import get_is_master |
|
|
|
|
@dataclass
class ProfilerArgs:
    """Configuration for the optional profiler runs (see maybe_run_profiler)."""

    # Master switch: when False, no profiling context is created at all.
    run: bool = False
    # Subdirectory (joined under the dump dir) where traces are written.
    trace_folder: str = "profiling"
    # Step at which the memory-snapshot profiler becomes active.
    mem_warmup: int = 100
    # Number of steps the memory-snapshot profiler records.
    mem_steps: int = 2
    # Step at which the PyTorch profiler becomes active.
    profile_warmup: int = 102
    # Number of steps the PyTorch profiler records.
    profile_steps: int = 2
|
|
|
|
| logger = logging.getLogger() |
|
|
|
|
def perfetto_to_html(json_file, html_file):
    """Embed a (possibly gzip-compressed) Chrome/Perfetto trace JSON file
    into a standalone HTML trace-viewer page.

    Args:
        json_file: Path to the ``.pt.trace.json`` or ``.pt.trace.json.gz`` file.
        html_file: Destination path for the generated self-contained HTML.
    """
    import gzip
    import string

    import viztracer

    # viztracer ships the trace-viewer HTML templates; locate them next to
    # the installed package.
    root = os.path.dirname(viztracer.__file__)
    sub = {}
    # Only treat the trace as gzipped when the suffix really is ".gz";
    # the previous substring check (".gz" in path) mis-fired on any path
    # merely containing ".gz".
    if str(json_file).endswith(".gz"):
        json_file = gzip.open(json_file)
    else:
        json_file = open(json_file, encoding="utf-8")
    with open(
        os.path.join(root, "html/trace_viewer_embedder.html"), encoding="utf-8"
    ) as f:
        tmpl = f.read()
    with open(os.path.join(root, "html/trace_viewer_full.html"), encoding="utf-8") as f:
        sub["trace_viewer_full"] = f.read()
    with json_file as j:
        content = j.read()
        # gzip.open yields bytes; decode so the template substitution works.
        if isinstance(content, bytes):
            content = content.decode("utf-8")
        # Escape closing script tags so the embedded JSON cannot break out
        # of the <script> block in the template.
        sub["json_data"] = content.replace("</script>", "<\\/script>")
    with open(html_file, "w+", encoding="utf-8") as output_file:
        output_file.write(string.Template(tmpl).substitute(sub))
|
|
|
|
class PyTorchProfilerWandb(PyTorchProfiler):
    """xformers PyTorch profiler that also uploads the resulting trace to
    Weights & Biases (as an embedded HTML trace viewer) on the master rank."""

    def __init__(self, main_profiler) -> None:
        self.main_profiler = main_profiler
        self.num_steps = 0
        self.pytorch_profiler = torch.profiler.profile(
            on_trace_ready=self._on_trace,
            profile_memory=True,
            record_shapes=True,
            # NOTE(review): with_stack deliberately off here — presumably to
            # keep trace size/overhead down; confirm before enabling.
            with_stack=False,
            with_flops=True,
            activities=self.ACTIVITIES,
        )

    def _analyze_trace(self, prof: profile):
        # Trace analysis can take a while on large traces; log around it so
        # the apparent hang is explained in the logs.
        logger.info("Begin analyze trace")
        super()._analyze_trace(prof)
        logger.info("End analyze trace")

    def _on_trace(self, prof: torch.profiler.profiler.profile) -> None:
        # Let the base class export the trace first, then upload it.
        super()._on_trace(prof)
        if get_is_master() and wandb.run is not None:
            traces = list(
                Path(self.main_profiler.output_dir).glob(
                    "profile_CPU_CUDA*/*.pt.trace.json*"
                )
            )
            if not traces:
                # Don't crash the training run over a missing upload
                # (the original indexed [0] and raised IndexError here).
                logger.warning("No profiler trace file found; skipping W&B upload")
                return
            filename = traces[0]
            html_path = str(filename).replace(".json", ".html")
            perfetto_to_html(filename, html_path)
            wandb.log({"profile_trace": wandb.Html(html_path)})
|
|
|
|
class MemSnapshotsProfilerWandb(MemSnapshotsProfiler):
    """xformers memory-snapshot profiler that uploads the rendered HTML
    memory plot to Weights & Biases on the master rank."""

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Base class finalizes the snapshot and writes the HTML plot.
        super().__exit__(exc_type, exc_val, exc_tb)
        if get_is_master() and wandb.run is not None:
            plots = list(
                Path(self.main_profiler.output_dir).glob("memory_trace_plot/*.html")
            )
            if not plots:
                # Guard the [0] index: missing plot should not crash training.
                logger.warning("No memory trace plot found; skipping W&B upload")
                return
            # Use a context manager so the file handle is closed (the
            # original passed a bare open() to wandb.Html and leaked it).
            with open(plots[0], encoding="utf-8") as f:
                wandb.log({"memory_trace": wandb.Html(f, inject=False)})
|
|
|
|
@contextlib.contextmanager
def maybe_run_profiler(dump_dir, module, config: ProfilerArgs):
    """Context manager that yields an active xformers profiler when
    ``config.run`` is True, and ``None`` otherwise.

    Args:
        dump_dir: Base output directory; traces go to
            ``dump_dir/config.trace_folder``.
        module: The (torch) module being profiled, forwarded to
            ``xformers.profiler.profile``.
        config: Profiler schedule/output configuration.

    Yields:
        The xformers profiler object, or ``None`` when profiling is disabled.
    """
    if config.run:
        trace_dir = os.path.join(dump_dir, config.trace_folder)

        logger.info(f"Profiling active. Traces will be saved at {trace_dir}")

        # Only the master rank creates the directory; exist_ok avoids the
        # TOCTOU race of the previous exists()-then-makedirs pattern.
        if get_is_master():
            os.makedirs(trace_dir, exist_ok=True)
        # All ranks wait until the directory exists before profiling starts.
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

        with xformers.profiler.profile(
            output_dir=trace_dir,
            module=module,
            schedule=[
                (
                    MemSnapshotsProfilerWandb,
                    config.mem_warmup,
                    config.mem_warmup + config.mem_steps,
                ),
                (
                    PyTorchProfilerWandb,
                    config.profile_warmup,
                    config.profile_warmup + config.profile_steps,
                ),
            ],
        ) as profiler:
            yield profiler

    else:
        # Profiling disabled: yield None so callers can use the same
        # `with` statement unconditionally. (Removed the dead, unused
        # `torch_profiler = contextlib.nullcontext()` local.)
        yield None
|
|