```python
import operator
from typing import Annotated, List, TypedDict

import torch
from langchain_core.messages import AIMessage, BaseMessage
from transformers import AutoModelForCausalLM, AutoTokenizer


# Define the subset of the graph state that this component operates on.
# BaseMessage (rather than AIMessage) lets the history hold system, human,
# and AI messages alike, which the prompt-assembly loop below relies on.
class GraphState(TypedDict):
    messages: Annotated[List[BaseMessage], operator.add]
```
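The `Annotated[..., operator.add]` annotation is the reducer LangGraph applies when merging a node's return value into the state: with lists, `operator.add` concatenates, so returned messages are appended to the history rather than replacing it. A quick illustration of the merge this produces:

```python
history = [AIMessage(content="first reply")]
update = [AIMessage(content="second reply")]
# LangGraph combines the old channel value with the node's update via the reducer
merged = operator.add(history, update)  # same as history + update
assert [m.content for m in merged] == ["first reply", "second reply"]
```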
```python
# --- Model loading ---
# With torch_dtype="auto" and device_map="auto", Hugging Face Accelerate
# picks the precision and places the weights across the available devices.
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    trust_remote_code=True,
)
```
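The node below assembles a plain-text prompt by hand with `User:`/`Assistant:` markers. If the checkpoint ships a chat template (not verified here), `tokenizer.apply_chat_template` is the more robust route, since it formats the conversation exactly the way the model was trained. A minimal sketch, assuming such a template exists:

```python
# Sketch only: assumes the tokenizer defines a chat template.
chat = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
input_ids = tokenizer.apply_chat_template(
    chat,
    add_generation_prompt=True,  # append the assistant turn marker
    return_tensors="pt",
)
```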
```python
def completion_node(state: GraphState) -> dict:
    """Call the language model to get a response.

    Args:
        state (GraphState): The current graph state, containing the message history.

    Returns:
        dict: A dict with the new AI message, to be appended to the state.
    """
    messages = state["messages"]

    # --- Prompt engineering ---
    # Assemble the prompt from the message history
    # (see the chat-template sketch above for an alternative).
    prompt = ""
    for msg in messages:
        if msg.type == "system":
            prompt += f"{msg.content}\n"
        elif msg.type == "human":
            prompt += f"User: {msg.content}\n"
        elif msg.type == "ai":
            prompt += f"Assistant: {msg.content}\n"
    prompt += "Assistant:"

    # --- Model invocation ---
    # Even with device_map="auto", the inputs must live on the same device
    # as the model's first layer; model.device points there.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=512,  # hard-coded for now
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, skipping the prompt.
    output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)

    # Return the response as an AIMessage to be appended to the graph state.
    return {"messages": [AIMessage(content=output)]}
```
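To exercise the node, it can be wired into a minimal LangGraph graph. A sketch, assuming `langgraph` is installed; the node name and the one-node topology are illustrative, not from the original:

```python
from langchain_core.messages import HumanMessage
from langgraph.graph import END, START, StateGraph

builder = StateGraph(GraphState)
builder.add_node("completion", completion_node)
builder.add_edge(START, "completion")
builder.add_edge("completion", END)
graph = builder.compile()

result = graph.invoke({"messages": [HumanMessage(content="Hello!")]})
print(result["messages"][-1].content)  # the model's reply
```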