""" No-sync sharding manager for TTRLVR that disables weight synchronization like original AZR. """ import logging import os from torch.distributed.device_mesh import DeviceMesh from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage from verl.utils.device import get_torch_device logger = logging.getLogger(__file__) logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) class NoSyncFSDPVLLMShardingManager(FSDPVLLMShardingManager): """ A custom sharding manager that disables weight synchronization between FSDP and VLLM. This mimics the behavior of original AZR where VLLM weights are not updated during training. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.sync_weights = False # Disable weight sync by default logger.info("🚫 NoSyncFSDPVLLMShardingManager initialized - weight sync disabled") @GPUMemoryLogger(role="no_sync_fsdp_vllm_sharding_manager", logger=logger) def __enter__(self): """ Enter the sharding manager context without syncing weights. This keeps VLLM using the initial weights throughout the epoch. """ # Just empty cache and set random states get_torch_device().empty_cache() log_gpu_memory_usage("After empty_cache in no-sync sharding manager", logger=logger) # Important: need to manually set the random states of each tp to be identical if self.device_mesh is not None: self.torch_random_states = get_torch_device().get_rng_state() get_torch_device().set_rng_state(self.gen_random_states) logger.info("✅ Entered no-sync sharding manager - skipping weight synchronization") @GPUMemoryLogger(role="no_sync_fsdp_vllm_sharding_manager", logger=logger) def __exit__(self, exc_type, exc_value, traceback): """ Exit the sharding manager context. """ # Set module back to train mode self.module.train() # Empty cache after compute get_torch_device().empty_cache() # Restore random states if self.device_mesh is not None: self.gen_random_states = get_torch_device().get_rng_state() get_torch_device().set_rng_state(self.torch_random_states) logger.info("✅ Exited no-sync sharding manager")