"""Custom rollout worker for VLLM sharing.

Reuses the Ray Actor VLLM from Steps 1-4 during Step 5.
"""

from typing import Any, Dict, List, Optional

import ray
import torch

from verl.protocol import DataProto
from verl.workers.rollout.vllm_rollout import vLLMRollout

class SharedVLLMRollout(vLLMRollout):
    """Rollout worker that reuses the VLLM of an existing Ray Actor."""

    def __init__(self, actor_handle: Optional[Any] = None, *args, **kwargs):
        """
        Args:
            actor_handle: Reference to the RemoteTestTimePipeline Ray Actor
                from Steps 1-4.
        """
        self.existing_vllm_actor = actor_handle

        if self.existing_vllm_actor is not None:
            print(f"🔗 Using existing VLLM Actor: {self.existing_vllm_actor}")
            self.use_external_vllm = True
            # Skip loading a second copy of the model; generation is
            # delegated to the shared actor instead.
            kwargs['skip_model_loading'] = True
        else:
            print("⚠️ No existing VLLM Actor provided, creating new VLLM")
            self.use_external_vllm = False

        super().__init__(*args, **kwargs)

    def generate(self, prompts: List[str], sampling_params: Dict[str, Any],
                 **kwargs) -> DataProto:
        """Generate text, delegating to the shared VLLM Actor when available."""
        if self.use_external_vllm and self.existing_vllm_actor is not None:
            print(f"📡 Calling remote VLLM Actor for {len(prompts)} prompts")

            # Block until the remote generation call finishes and fetch the result.
            result = ray.get(
                self.existing_vllm_actor.generate_batch_vllm.remote(
                    prompts=prompts,
                    max_tokens=sampling_params.get('max_tokens', 512),
                    temperature=sampling_params.get('temperature', 0.7),
                    top_p=sampling_params.get('top_p', 1.0),
                    n=sampling_params.get('n', 1),
                )
            )
            return self._convert_to_dataproto(result)
        else:
            # No shared actor: fall back to the locally loaded VLLM engine.
            return super().generate(prompts, sampling_params, **kwargs)

    def _convert_to_dataproto(self, result: Dict[str, Any]) -> DataProto:
        """Convert the Ray Actor result dict into a DataProto."""
        responses = result.get('responses', [])

        data_dict = {
            'responses': responses,
            'input_ids': result.get('input_ids', []),
            'attention_mask': result.get('attention_mask', []),
        }
        return DataProto.from_single_dict(data_dict)

    def update_weight(self, state_dict: Dict[str, torch.Tensor]):
        """Update model weights; skipped when the VLLM is shared."""
        if self.use_external_vllm:
            # The shared actor owns the weights, so there is nothing to sync here.
            print("🔄 Skipping weight update for shared VLLM (handled by Ray Actor)")
            return
        super().update_weight(state_dict)
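

# --- Usage sketch ------------------------------------------------------------
# A minimal, hedged example of wiring this worker to a shared actor. It
# assumes the RemoteTestTimePipeline actor from Steps 1-4 was registered
# under the (hypothetical) name "test_time_pipeline" and exposes a
# `generate_batch_vllm(prompts, max_tokens, temperature, top_p, n)` remote
# method returning a dict with 'responses', 'input_ids', and 'attention_mask'.
# Any extra constructor arguments that vLLMRollout requires are elided here.
if __name__ == "__main__":
    ray.init(address="auto", ignore_reinit_error=True)

    # Look up the long-lived pipeline actor by name; assumes it was created
    # with `.options(name="test_time_pipeline", lifetime="detached")`.
    pipeline_actor = ray.get_actor("test_time_pipeline")

    rollout = SharedVLLMRollout(actor_handle=pipeline_actor)
    output = rollout.generate(
        prompts=["Hello, world!"],
        sampling_params={"max_tokens": 64, "temperature": 0.7},
    )
    print(output)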