cafe3310's picture
refactor: 重构项目结构并优化模型加载方式
551e9e2
raw
history blame
1.99 kB
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain_core.messages import AIMessage
from typing import TypedDict, Annotated, List
import operator
# 定义此组件操作的图状态的子集
class GraphState(TypedDict):
messages: Annotated[List[AIMessage], operator.add]
# --- 模型加载 ---
# 使用 "auto" 模式加载模型和分词器,Hugging Face Accelerate 会自动处理设备和精度
MODEL_NAME = "inclusionAI/Ring-mini-2.0"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True
)
def completion_node(state: GraphState) -> dict:
"""
一个调用语言模型以获取响应的节点。
Args:
state (GraphState): 图的当前状态,包含消息历史。
Returns:
dict: 一个包含新 AI 消息的字典,该消息将被添加到状态中。
"""
messages = state["messages"]
# --- 提示工程 ---
# 从消息历史中组装提示。
prompt = ""
for msg in messages:
if msg.type == "system":
prompt += f"{msg.content}\n"
elif msg.type == "human":
prompt += f"User: {msg.content}\n"
elif msg.type == "ai":
prompt += f"Assistant: {msg.content}\n"
prompt += "Assistant:"
# --- 模型调用 ---
# 使用 device_map="auto" 时,我们无需手动将张量移动到特定设备
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output_ids = model.generate(
input_ids,
max_new_tokens=512, # 暂时硬编码
do_sample=True,
pad_token_id=tokenizer.eos_token_id,
)
output = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
# 以 AIMessage 的形式返回响应,以添加到图的状态中。
return {"messages": [AIMessage(content=output)]}