| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import unittest |
| from unittest.mock import MagicMock, patch |
|
|
| from verl.utils import omega_conf_to_dataclass |
| from verl.utils.profiler.config import NsightToolConfig, ProfilerConfig |
| from verl.utils.profiler.nvtx_profile import NsightSystemsProfiler |
|
|
|
|
class TestProfilerConfig(unittest.TestCase):
    """Tests for building ``ProfilerConfig`` dataclasses and for their immutability."""

    def test_config_init(self):
        """Each role's ``profiler`` section in ``ppo_trainer`` converts to a matching ``ProfilerConfig``."""
        import os

        from hydra import compose, initialize_config_dir

        # NOTE(review): the config dir is resolved relative to the current working
        # directory, so this test presumably has to be launched from the repo root.
        with initialize_config_dir(config_dir=os.path.abspath("verl/trainer/config")):
            cfg = compose(config_name="ppo_trainer")
            role_configs = {
                "actor": cfg.actor_rollout_ref.actor.profiler,
                "rollout": cfg.actor_rollout_ref.rollout.profiler,
                "ref": cfg.actor_rollout_ref.ref.profiler,
                "critic": cfg.critic.profiler,
                "reward_model": cfg.reward_model.profiler,
            }
            for role, config in role_configs.items():
                # subTest reports which role failed instead of stopping at the first one.
                with self.subTest(role=role):
                    profiler_config = omega_conf_to_dataclass(config)
                    self.assertIsInstance(profiler_config, ProfilerConfig)

                    # Every checked field must round-trip unchanged from the YAML config.
                    for field_name in ("tool", "enable", "all_ranks", "ranks", "save_path"):
                        self.assertEqual(getattr(profiler_config, field_name), getattr(config, field_name))

                    # Unknown attributes are rejected on the dataclass...
                    with self.assertRaises(AttributeError):
                        _ = profiler_config.non_existing_key
                    # ...while ``get`` must agree with DictConfig.get for unknown keys,
                    # both without and with an explicit default.
                    self.assertEqual(config.get("non_existing_key"), profiler_config.get("non_existing_key"))
                    self.assertEqual(config.get("non_existing_key", 1), profiler_config.get("non_existing_key", 1))

    def test_frozen_config(self):
        """Test that modifying frozen keys in ProfilerConfig raises exceptions."""
        from dataclasses import FrozenInstanceError

        config = ProfilerConfig(all_ranks=False, ranks=[0])

        for key, new_value in (("all_ranks", True), ("ranks", [1, 2, 3])):
            with self.subTest(key=key):
                # Attribute assignment on the frozen dataclass must be refused.
                with self.assertRaises(FrozenInstanceError):
                    setattr(config, key, new_value)
                # Item-style assignment must be rejected as well (expected as TypeError).
                with self.assertRaises(TypeError):
                    config[key] = new_value
|
|
|
|
class TestNsightSystemsProfiler(unittest.TestCase):
    """Test suite for NsightSystemsProfiler functionality.

    Test Plan:
    1. Initialization: Verify profiler state after creation
    2. Basic Profiling: Test start/stop functionality
    3. Discrete Mode: TODO: Test discrete profiling behavior
    4. Annotation: Test the annotate decorator in non-discrete mode
       (TODO: the discrete-mode path of the decorator is not covered yet)
    5. Config Validation: covered by ``TestProfilerConfig``, not by this suite
    """

    def setUp(self):
        # Every test gets a fresh non-discrete profiler enabled on all ranks.
        self.config = ProfilerConfig(enable=True, all_ranks=True)
        self.rank = 0
        self.profiler = NsightSystemsProfiler(self.rank, self.config, tool_config=NsightToolConfig(discrete=False))

    def test_initialization(self):
        # With all_ranks=True rank 0 is selected; no step is active before start().
        self.assertEqual(self.profiler.this_rank, True)
        self.assertEqual(self.profiler.this_step, False)

    def test_start_stop_profiling(self):
        with patch("torch.cuda.profiler.start") as mock_start, patch("torch.cuda.profiler.stop") as mock_stop:
            # start() opens a step and starts the CUDA profiler exactly once...
            self.profiler.start()
            self.assertTrue(self.profiler.this_step)
            mock_start.assert_called_once()
            mock_stop.assert_not_called()

            # ...and stop() closes the step and stops the CUDA profiler exactly once.
            self.profiler.stop()
            self.assertFalse(self.profiler.this_step)
            mock_stop.assert_called_once()

    def test_annotate_decorator(self):
        # Stand-in for a worker object that owns the profiler; mark a step as active.
        mock_self = MagicMock()
        mock_self.profiler = self.profiler
        mock_self.profiler.this_step = True
        decorator = mock_self.profiler.annotate(message="test")

        @decorator
        def test_func(_owner, *args, **kwargs):
            return "result"

        with (
            patch("torch.cuda.profiler.start") as mock_start,
            patch("torch.cuda.profiler.stop") as mock_stop,
            patch("verl.utils.profiler.nvtx_profile.mark_start_range") as mock_start_range,
            patch("verl.utils.profiler.nvtx_profile.mark_end_range") as mock_end_range,
        ):
            result = test_func(mock_self)
            # The wrapped function's return value passes through untouched.
            self.assertEqual(result, "result")
            # In non-discrete mode the decorator only emits an NVTX range...
            mock_start_range.assert_called_once()
            mock_end_range.assert_called_once()
            # ...and leaves the CUDA profiler itself alone.
            mock_start.assert_not_called()
            mock_stop.assert_not_called()
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in it.
    unittest.main()
|
|