# neural-mesh/test/utils/no_sync_sharding_manager.py
"""
No-sync sharding manager for TTRLVR that disables FSDP-to-vLLM weight synchronization, mimicking the original AZR behavior.
"""
import logging
import os
from torch.distributed.device_mesh import DeviceMesh
from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage
from verl.utils.device import get_torch_device

logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))


class NoSyncFSDPVLLMShardingManager(FSDPVLLMShardingManager):
"""
    A custom sharding manager that disables weight synchronization between FSDP and vLLM.
    This mimics the behavior of the original AZR, where vLLM weights are not updated during training.
"""

    def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.sync_weights = False # Disable weight sync by default
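        # Note: the actual skip happens in the __enter__ override below, which never
        # loads FSDP weights into vLLM; this flag mainly documents that intent and is
        # assumed not to be consumed by the parent FSDPVLLMShardingManager.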
logger.info("🚫 NoSyncFSDPVLLMShardingManager initialized - weight sync disabled")

    @GPUMemoryLogger(role="no_sync_fsdp_vllm_sharding_manager", logger=logger)
def __enter__(self):
"""
Enter the sharding manager context without syncing weights.
        This keeps vLLM using the initial weights throughout the epoch.
"""
# Just empty cache and set random states
get_torch_device().empty_cache()
log_gpu_memory_usage("After empty_cache in no-sync sharding manager", logger=logger)
# Important: need to manually set the random states of each tp to be identical
if self.device_mesh is not None:
self.torch_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.gen_random_states)
        logger.info("✅ Entered no-sync sharding manager - skipping weight synchronization")

    @GPUMemoryLogger(role="no_sync_fsdp_vllm_sharding_manager", logger=logger)
def __exit__(self, exc_type, exc_value, traceback):
"""
Exit the sharding manager context.
"""
# Set module back to train mode
self.module.train()
# Empty cache after compute
get_torch_device().empty_cache()
# Restore random states
if self.device_mesh is not None:
self.gen_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.torch_random_states)
        logger.info("✅ Exited no-sync sharding manager")
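

# Hedged usage sketch (illustrative only): the constructor arguments below mirror
# what FSDPVLLMShardingManager is assumed to take (module, inference_engine,
# model_config, full_params, device_mesh); the exact signature may differ across
# verl versions, and `fsdp_module`, `vllm_engine`, `prompts`, and `sampling_params`
# are hypothetical placeholders.
#
#   sharding_manager = NoSyncFSDPVLLMShardingManager(
#       module=fsdp_module,            # FSDP-wrapped actor module
#       inference_engine=vllm_engine,  # vLLM engine used for rollout generation
#       model_config=model_config,
#       full_params=False,
#       device_mesh=device_mesh,
#   )
#
#   # Entering the context skips the usual FSDP -> vLLM weight sync, so vLLM keeps
#   # serving its initial weights for the whole epoch (original AZR behavior).
#   with sharding_manager:
#       outputs = vllm_engine.generate(prompts, sampling_params)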