| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import json |
| import re |
|
|
| from ...utils.constants import IGNORE_INDEX |
| from ...utils.helper import get_tokenizer |
| from ...utils.plugin import BasePlugin |
| from ...utils.types import Message, ModelInput, Processor, ToolCall |
|
|
|
|
class RenderingPlugin(BasePlugin):
    """Dispatch hub for chat-template handlers registered per template name."""

    def render_messages(
        self,
        processor: Processor,
        messages: list[Message],
        tools: str | None = None,
        is_generate: bool = False,
    ) -> ModelInput:
        """Render messages in the template format."""
        # Look up the handler registered under "render_messages" and delegate.
        render_fn = self["render_messages"]
        return render_fn(processor, messages, tools, is_generate)

    def parse_messages(self, generated_text: str) -> Message:
        """Parse messages in the template format."""
        # Look up the handler registered under "parse_messages" and delegate.
        parse_fn = self["parse_messages"]
        return parse_fn(generated_text)
|
|
|
|
def _update_model_input(
    processor: Processor,
    input_ids: list[int],
    labels: list[int],
    loss_weights: list[float],
    temp_str: str,
    temp_weight: float,
) -> str:
    """Tokenize ``temp_str`` and flush it into the model-input buffers in place.

    Args:
        processor: Processor whose tokenizer encodes ``temp_str``.
        input_ids: Token-id buffer, extended in place.
        labels: Label buffer, extended in place; filled with IGNORE_INDEX when
            ``temp_weight`` is (effectively) zero so those tokens take no loss.
        loss_weights: Per-token loss-weight buffer, extended in place.
            (Annotation fixed: it stores floats, not ints.)
        temp_str: Pending rendered text to flush; empty string is a no-op.
        temp_weight: Loss weight applied uniformly to every token of ``temp_str``.

    Returns:
        str: Always ``""`` so callers can reset their accumulator with
        ``temp_str = _update_model_input(...)``.
    """
    if not temp_str:
        return ""

    tokenizer = get_tokenizer(processor)
    # No special tokens: template control tokens are spelled out in temp_str itself.
    temp_ids = tokenizer.encode(temp_str, add_special_tokens=False)
    input_ids.extend(temp_ids)
    loss_weights.extend([temp_weight] * len(temp_ids))
    # Weights at or below the epsilon are treated as zero: mask out of the loss.
    if temp_weight > 1e-6:
        labels.extend(temp_ids)
    else:
        labels.extend([IGNORE_INDEX] * len(temp_ids))

    return ""
|
|
|
|
@RenderingPlugin("qwen3_nothink").register("render_messages")
def render_qwen3_nothink_messages(
    processor: Processor,
    messages: list[Message],
    tools: str | None = None,
    is_generate: bool = False,
) -> ModelInput:
    """Render messages in the Qwen3 nothink template format.

    See https://huggingface.co/spaces/huggingfacejs/chat-template-playground?modelId=Qwen/Qwen3-4B-Instruct-2507

    Args:
        processor: Processor providing the tokenizer used to encode rendered text.
        messages: Conversation turns; each has a "role", a list of typed
            "content" items, and an optional per-turn "loss_weight".
        tools: Optional JSON string describing available tools (one tool object
            or a list of them); folded into the system turn when present.
        is_generate: When True, append an open assistant header so the model
            continues generating from there.

    Returns:
        ModelInput: input_ids / attention_mask / labels / loss_weights of equal
        length; labels hold IGNORE_INDEX wherever the loss weight is ~0.

    Raises:
        ValueError: On an unsupported content type or non-JSON ``tools`` string.
    """
    input_ids, labels, loss_weights = [], [], []
    # temp_str accumulates rendered text that shares one loss weight; it is
    # flushed via _update_model_input, which tokenizes it and returns "" to
    # reset the accumulator.
    temp_str, temp_weight = "", 0.0
    if tools:
        # Tools are merged into the system turn; an explicit leading system
        # message (if any) becomes the preamble above the "# Tools" section.
        temp_str += "<|im_start|>system\n"
        if messages[0]["role"] == "system":
            for content in messages[0]["content"]:
                if content["type"] == "text":
                    temp_str += content["value"]
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "\n\n"
            temp_weight = messages[0].get("loss_weight", 0.0)

        temp_str += (
            "# Tools\n\nYou may call one or more functions to assist with the user query.\n\n"
            "You are provided with function signatures within <tools></tools> XML tags:\n<tools>"
        )
        try:
            tools = json.loads(tools)
        except json.JSONDecodeError:
            raise ValueError(f"Invalid tools format: {str(tools)}.")

        # Normalize a single tool object to a one-element list.
        if not isinstance(tools, list):
            tools = [tools]

        # One JSON signature per line inside <tools>…</tools>.
        for tool in tools:
            temp_str += "\n" + json.dumps(tool, ensure_ascii=False)

        temp_str += (
            "\n</tools>\n\nFor each function call, return a json object with function name "
            'and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{"name": '
            '<function-name>, "arguments": <args-json-object>}\n</tool_call><|im_end|>\n'
        )
    elif messages[0]["role"] == "system":
        # Plain system turn (no tools).
        temp_str += "<|im_start|>system\n"
        for content in messages[0]["content"]:
            if content["type"] == "text":
                temp_str += content["value"]
            else:
                raise ValueError(f"Unsupported content type: {content['type']}")

        temp_str += "<|im_end|>\n"
        temp_weight = messages[0].get("loss_weight", 0.0)

    # Flush the system/tools prefix with its own loss weight.
    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    for turn_idx, message in enumerate(messages):
        # A leading system message was already rendered above, so it is skipped
        # here (turn_idx == 0); system turns appearing later are rendered inline.
        if message["role"] == "user" or (message["role"] == "system" and turn_idx != 0):
            temp_str += "<|im_start|>" + message["role"] + "\n"
            for content in message["content"]:
                if content["type"] == "text":
                    temp_str += content["value"]
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "<|im_end|>\n"
            temp_weight = message.get("loss_weight", 0.0)
        elif message["role"] == "assistant":
            temp_str += "<|im_start|>" + message["role"] + "\n"
            for val_idx, content in enumerate(message["content"]):
                if content["type"] == "text":
                    temp_str += content["value"]
                elif content["type"] == "reasoning":
                    temp_str += "<thinking>\n" + content["value"] + "\n</thinking>\n\n"
                elif content["type"] == "tool_call":
                    # Separate a tool call from preceding text or another tool
                    # call with a newline (reasoning already ends with "\n\n").
                    if val_idx != 0 and message["content"][val_idx - 1]["type"] in ["text", "tool_call"]:
                        temp_str += "\n"

                    try:
                        tool_call: ToolCall = json.loads(content["value"])
                    except json.JSONDecodeError:
                        raise ValueError(f"Invalid tool call format: {content['value']}.")

                    temp_str += (
                        '<tool_call>\n{"name": "'
                        + tool_call["name"]
                        + '", "arguments": '
                        + json.dumps(tool_call["arguments"], ensure_ascii=False)
                        + "}\n</tool_call>"
                    )

                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "<|im_end|>\n"
            # Assistant turns are trained on by default (weight 1.0).
            temp_weight = message.get("loss_weight", 1.0)
        elif message["role"] == "tool":
            # Consecutive tool messages are merged into a single user turn,
            # each wrapped in its own <tool_response> block.
            if turn_idx == 0 or messages[turn_idx - 1]["role"] != "tool":
                temp_str += "<|im_start|>user"

            temp_str += "\n<tool_response>\n"
            for content in message["content"]:
                if content["type"] == "text":
                    temp_str += content["value"]
                else:
                    raise ValueError(f"Unsupported content type: {content['type']}")

            temp_str += "\n</tool_response>"
            # Close the merged user turn only after the last tool message in the run.
            if turn_idx == len(messages) - 1 or messages[turn_idx + 1]["role"] != "tool":
                temp_str += "<|im_end|>\n"

            temp_weight = message.get("loss_weight", 0.0)

        # NOTE(review): messages with any other role fall through silently and
        # render nothing — confirm this is intentional rather than raising.
        # Flush this turn with its loss weight.
        temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    if is_generate:
        # Open an assistant header for generation; never trained on.
        temp_str += "<|im_start|>assistant\n"
        temp_weight = 0.0

    temp_str = _update_model_input(processor, input_ids, labels, loss_weights, temp_str, temp_weight)

    attention_mask = [1] * len(input_ids)
    return ModelInput(
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
        loss_weights=loss_weights,
    )
|
|
|
|
@RenderingPlugin("qwen3_nothink").register("parse_messages")
def parse_qwen3_nothink_message(generated_text: str) -> Message:
    """Parse a message in the Qwen3 nothink template format. Supports interleaved reasoning and tool calls.

    Registered as "parse_messages" (fixed: was "parse_message") so that
    ``RenderingPlugin.parse_messages`` — which looks up ``self["parse_messages"]``
    — actually finds this handler.

    Args:
        generated_text (str): The generated text in the Qwen3 nothink template format.

    Returns:
        Message: The parsed message.

    Raises:
        ValueError: If a <tool_call> payload is not valid JSON.
    """
    # Match <thinking>…</thinking> and <tool_call>…</tool_call> spans; the
    # backreference \1 forces the closing tag to match the opening one, and
    # re.DOTALL lets payloads span multiple lines.
    pattern = re.compile(r"<(thinking|tool_call)>\s*(.*?)\s*</\1>\s*", re.DOTALL)
    content = []
    last_end = 0
    for match in pattern.finditer(generated_text):
        start, end = match.span()
        # Plain text between the previous tag (or start of string) and this tag.
        if start > last_end:
            text = generated_text[last_end:start].strip()
            if text:
                content.append({"type": "text", "value": text})

        tag_type = match.group(1)
        tag_value = match.group(2).strip()
        if tag_type == "thinking":
            content.append({"type": "reasoning", "value": tag_value})
        elif tag_type == "tool_call":
            try:
                # Validate only; the raw JSON string is stored as the value.
                json.loads(tag_value)
            except json.JSONDecodeError:
                raise ValueError(f"Invalid tool call format: {tag_value}.")

            content.append({"type": "tool_call", "value": tag_value})

        last_end = end

    # Trailing plain text after the last tag.
    if last_end < len(generated_text):
        text = generated_text[last_end:].strip()
        if text:
            content.append({"type": "text", "value": text})

    return Message(role="assistant", content=content)
|
|