# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, List, Tuple

from swift.plugin.orm import MathAccuracy

if TYPE_CHECKING:
    from swift.llm.template import RolloutInferRequest
    from swift.llm.utils import Messages


class Env(ABC):
    """Base environment interface for GRPO training."""

    def __init__(self, env_config):
        """Store the environment config; subclasses interpret it as needed."""
        self.env_config = env_config

    @abstractmethod
    async def reset(self, config: 'RolloutInferRequest') -> Tuple[str, Dict[str, Any], str]:
        """Reset environment to initial state.

        Args:
            config: Initial configuration containing dataset information

        Returns:
            Tuple of (observation, info, system_message):
                - observation: Initial query string for the agent
                - info: Environment debug information as dict
                - system_message: System prompt for this trajectory
        """
        pass

    @abstractmethod
    async def step(self, action: 'Messages') -> Tuple[str, float, bool, Dict[str, Any]]:
        """Execute one step in the environment.

        Args:
            action: LLM response messages containing the action to execute

        Returns:
            Tuple of (next_observation, reward, done, info):
                - next_observation: Next observation string
                - reward: Reward value for this step
                - done: Whether the episode is finished
                - info: Additional information as dict
        """
        pass

    @abstractmethod
    async def close(self):
        """Clean up environment resources."""
        pass


@lru_cache(maxsize=1)
def _get_qwen_tokenizer(model_name: str = 'Qwen/Qwen2.5-3B-Instruct'):
    """Load the Qwen tokenizer once and cache it for all later calls.

    The original code called ``AutoTokenizer.from_pretrained`` on every
    token count, which is prohibitively slow inside a rollout loop.
    """
    from modelscope import AutoTokenizer
    return AutoTokenizer.from_pretrained(model_name)


def count_qwen_tokens(messages: List[Dict[str, Any]], max_tokens: int = 2048) -> Tuple[int, bool]:
    """Count the chat-template tokens of *messages* with the Qwen tokenizer.

    Args:
        messages: List of messages in OpenAI format
        max_tokens: Token limit to compare against, default 2048

    Returns:
        Tuple[int, bool]: (token count, whether the count reaches or exceeds
        ``max_tokens``). On any failure (e.g. tokenizer unavailable) returns
        ``(0, False)`` so callers treat the trajectory as within limit.
    """
    try:
        tokenizer = _get_qwen_tokenizer()
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        token_count = len(tokenizer.encode(text))
        return token_count, token_count >= max_tokens
    except Exception as e:
        # Best effort: token counting must never crash a rollout.
        print(f'Token calculation failed: {e}')
        return 0, False


class SimpleMathEnv(Env):
    """Math environment: reward 1.0 for a correct answer, retry hint otherwise."""

    # Observation sent back when the answer is wrong and the episode continues.
    tips_prompt = 'The answer is not correct, It seems You made a mistake, you need to recheck very carefully.'

    def __init__(self, env_config):
        super().__init__(env_config)
        self.acc_func = MathAccuracy()  # exact-match math accuracy scorer
        self.solution = ''  # ground-truth solution, set per-trajectory in reset()

    async def reset(self, config: 'RolloutInferRequest') -> Tuple[str, Dict[str, Any], str]:
        obs = config.data_dict['problem']
        info = {}
        self.solution = config.data_dict['solution']
        # NOTE(review): the original prompt text had lost its tag markup
        # ("enclosed within and tags, respectively"); restored the standard
        # <think>/<answer> wording the surrounding sentence clearly intends.
        system_prompt = (
            'A conversation between User and Assistant. The user asks a question, and the Assistant solves it. '
            'The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. \n'
            'The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, '
            'respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer> ')
        return obs, info, system_prompt

    async def step(self, action: 'Messages') -> Tuple[str, float, bool, Dict[str, Any]]:
        next_obs = self.tips_prompt
        reward = 0.0
        done = False
        info = {}
        # Score only the assistant's latest message against the stored solution.
        acc = self.acc_func([action[-1]['content']], [self.solution])[0]
        if count_qwen_tokens(action)[1]:
            # Trajectory grew past the token budget: stop regardless of accuracy.
            done = True
            info['stop_reason'] = 'Exceeded maximum length'
        if acc == 1:
            done = True
            reward = 1.0
            info['stop_reason'] = 'Correct'
        info['math_reward'] = reward
        return next_obs, reward, done, info

    async def close(self):
        # No external resources to release.
        pass


# Registry for environments
envs = {'math_env': SimpleMathEnv}