""" VLLM 공유를 위한 Custom Rollout Worker Step 1-4의 Ray Actor VLLM을 Step 5에서 재사용 """ import ray from typing import Optional, Any, Dict, List from verl.workers.rollout.vllm_rollout import vLLMRollout from verl.protocol import DataProto import torch class SharedVLLMRollout(vLLMRollout): """기존 Ray Actor의 VLLM을 재사용하는 Rollout Worker""" def __init__(self, actor_handle: Optional[Any] = None, *args, **kwargs): """ Args: actor_handle: Step 1-4의 RemoteTestTimePipeline Ray Actor 참조 """ self.existing_vllm_actor = actor_handle if self.existing_vllm_actor is not None: print(f"🔗 Using existing VLLM Actor: {self.existing_vllm_actor}") # VLLM 엔진 생성을 스킵하고 Actor 참조만 저장 self.use_external_vllm = True # 부모 클래스의 __init__을 호출하되, 모델 로딩은 스킵 kwargs['skip_model_loading'] = True else: print("⚠️ No existing VLLM Actor provided, creating new VLLM") self.use_external_vllm = False super().__init__(*args, **kwargs) def generate(self, prompts: List[str], sampling_params: Dict[str, Any], **kwargs) -> DataProto: """ 텍스트 생성 - 기존 VLLM Actor 사용 """ if self.use_external_vllm and self.existing_vllm_actor is not None: # Ray Actor의 generate 메서드 호출 print(f"📡 Calling remote VLLM Actor for {len(prompts)} prompts") # RemoteTestTimePipeline의 generate 메서드 호출 # 이 부분은 RemoteTestTimePipeline의 인터페이스에 맞게 조정 필요 result = ray.get( self.existing_vllm_actor.generate_batch_vllm.remote( prompts=prompts, max_tokens=sampling_params.get('max_tokens', 512), temperature=sampling_params.get('temperature', 0.7), top_p=sampling_params.get('top_p', 1.0), n=sampling_params.get('n', 1) ) ) # 결과를 DataProto 형식으로 변환 return self._convert_to_dataproto(result) else: # 기존 VLLM 사용 (부모 클래스 메서드) return super().generate(prompts, sampling_params, **kwargs) def _convert_to_dataproto(self, result: Dict[str, Any]) -> DataProto: """Ray Actor 결과를 DataProto로 변환""" # RemoteTestTimePipeline의 출력 형식에 맞게 조정 필요 # 예시 구현: responses = result.get('responses', []) data_dict = { 'responses': responses, 'input_ids': result.get('input_ids', []), 'attention_mask': result.get('attention_mask', []) } return DataProto.from_single_dict(data_dict) def update_weight(self, state_dict: Dict[str, torch.Tensor]): """ 모델 가중치 업데이트 - VLLM 공유 시에는 스킵 """ if self.use_external_vllm: print("🔄 Skipping weight update for shared VLLM (handled by Ray Actor)") return else: super().update_weight(state_dict)