neural-mesh / test /utils /vllm_sharing_rollout.py
hjkim00's picture
Upload TestTime-RLVR-v2 from Full-pipeline-relative_0827 branch
f50dc54 verified
"""
VLLM 공유를 위한 Custom Rollout Worker
Step 1-4의 Ray Actor VLLM을 Step 5에서 재사용
"""
import ray
from typing import Optional, Any, Dict, List
from verl.workers.rollout.vllm_rollout import vLLMRollout
from verl.protocol import DataProto
import torch
class SharedVLLMRollout(vLLMRollout):
"""기존 Ray Actor의 VLLM을 재사용하는 Rollout Worker"""
def __init__(self,
actor_handle: Optional[Any] = None,
*args, **kwargs):
"""
Args:
actor_handle: Step 1-4의 RemoteTestTimePipeline Ray Actor 참조
"""
self.existing_vllm_actor = actor_handle
if self.existing_vllm_actor is not None:
print(f"🔗 Using existing VLLM Actor: {self.existing_vllm_actor}")
# VLLM 엔진 생성을 스킵하고 Actor 참조만 저장
self.use_external_vllm = True
# 부모 클래스의 __init__을 호출하되, 모델 로딩은 스킵
kwargs['skip_model_loading'] = True
else:
print("⚠️ No existing VLLM Actor provided, creating new VLLM")
self.use_external_vllm = False
super().__init__(*args, **kwargs)
def generate(self,
prompts: List[str],
sampling_params: Dict[str, Any],
**kwargs) -> DataProto:
"""
텍스트 생성 - 기존 VLLM Actor 사용
"""
if self.use_external_vllm and self.existing_vllm_actor is not None:
# Ray Actor의 generate 메서드 호출
print(f"📡 Calling remote VLLM Actor for {len(prompts)} prompts")
# RemoteTestTimePipeline의 generate 메서드 호출
# 이 부분은 RemoteTestTimePipeline의 인터페이스에 맞게 조정 필요
result = ray.get(
self.existing_vllm_actor.generate_batch_vllm.remote(
prompts=prompts,
max_tokens=sampling_params.get('max_tokens', 512),
temperature=sampling_params.get('temperature', 0.7),
top_p=sampling_params.get('top_p', 1.0),
n=sampling_params.get('n', 1)
)
)
# 결과를 DataProto 형식으로 변환
return self._convert_to_dataproto(result)
else:
# 기존 VLLM 사용 (부모 클래스 메서드)
return super().generate(prompts, sampling_params, **kwargs)
def _convert_to_dataproto(self, result: Dict[str, Any]) -> DataProto:
"""Ray Actor 결과를 DataProto로 변환"""
# RemoteTestTimePipeline의 출력 형식에 맞게 조정 필요
# 예시 구현:
responses = result.get('responses', [])
data_dict = {
'responses': responses,
'input_ids': result.get('input_ids', []),
'attention_mask': result.get('attention_mask', [])
}
return DataProto.from_single_dict(data_dict)
def update_weight(self, state_dict: Dict[str, torch.Tensor]):
"""
모델 가중치 업데이트 - VLLM 공유 시에는 스킵
"""
if self.use_external_vllm:
print("🔄 Skipping weight update for shared VLLM (handled by Ray Actor)")
return
else:
super().update_weight(state_dict)