| |
|
| | from typing import Dict |
| | from dataclasses import dataclass, field |
| | from abc import ABC, abstractmethod |
| |
|
| | from larm.data.utils.tensor_utils import TensorHelper, TensorConfig |
| |
|
| |
|
| | @dataclass |
| | class InteractionConfig: |
| | max_turns: int |
| | max_start_length: int |
| | max_prompt_length: int |
| | max_response_length: int |
| | max_obs_length: int |
| | do_sample: bool |
| | temperature: float |
| |
|
| | @dataclass |
| | class InteractionDataProto: |
| | batch: Dict = field(default_factory=dict) |
| | no_tensor_batch: Dict = field(default_factory=dict) |
| |
|
| | class InteractionManager(ABC): |
| | |
| | def __init__( |
| | self, |
| | tokenizer, |
| | actor_rollout_wg, |
| | config: InteractionConfig, |
| | is_validation: bool = False, |
| | ): |
| | tokenizer = tokenizer.tokenizer |
| | self.tokenizer = tokenizer |
| | self.tokenizer.padding_side = "left" |
| | self.actor_rollout_wg = actor_rollout_wg |
| | self.config = config |
| | self.is_validation = is_validation |
| | |
| | assert tokenizer.pad_token_id is not None |
| | self.tensor_fn = TensorHelper(TensorConfig( |
| | pad_token_id=tokenizer.pad_token_id, |
| | max_prompt_length=config.max_prompt_length, |
| | max_obs_length=config.max_obs_length, |
| | max_start_length=config.max_start_length |
| | )) |
| |
|
| | |
| | try: |
| | im_end_ids = self.tokenizer.encode("<|im_end|>", add_special_tokens=False) |
| | if isinstance(im_end_ids, list) and len(im_end_ids) == 1: |
| | self.chat_eos_token_id = im_end_ids[0] |
| | else: |
| | self.chat_eos_token_id = self.tokenizer.eos_token_id |
| | except Exception: |
| | self.chat_eos_token_id = self.tokenizer.eos_token_id |
| | |
| | @abstractmethod |
| | def run_agent_loop(self, gen_batch: InteractionDataProto) -> InteractionDataProto: |
| | ... |