import streamlit as st
import os
from huggingface_hub import login, InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
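# Note: os, login, and InferenceClient are unused on the local-loading path below;
# they are kept for the alternative hosted Inference API approach mentioned later.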
# Page configuration
st.set_page_config(
    page_title="MiniMind Chatbot",
    page_icon="🤖",
    layout="centered",
)

# Title and description
st.title("🤖 MiniMind Chatbot")
st.markdown("This is a chat app based on the MiniMind model. Enter your question and the AI will answer it for you!")
# 设置边栏 | |
with st.sidebar: | |
st.header("模型设置") | |
temperature = st.slider("温度", min_value=0.1, max_value=1.0, value=0.7, step=0.1, | |
help="较高的值使输出更随机,较低的值使其更确定") | |
max_tokens = st.slider("最大生成长度", min_value=64, max_value=1024, value=512, step=64) | |
st.markdown("---") | |
st.markdown("## 关于模型") | |
st.markdown(""" | |
MiniMind是一个轻量级语言模型,可以进行文本生成、问答和聊天等任务。 | |
[查看模型主页](https://huggingface.co/xingyu1996/minimind) | |
""") | |
# 初始化聊天历史 | |
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
# 显示聊天历史 | |
for message in st.session_state.messages: | |
with st.chat_message(message["role"]): | |
st.markdown(message["content"]) | |
# User input
if prompt := st.chat_input("Type your question..."):
    # Append the user message to the history
    st.session_state.messages.append({"role": "user", "content": prompt})
    # Render the user message
    with st.chat_message("user"):
        st.markdown(prompt)

    # Placeholder for the assistant message
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        # Load the model (two approaches are possible; local loading via transformers is used here)
        try:
            with st.spinner("Thinking..."):
                @st.cache_resource
                def load_model():
                    """Load the model and tokenizer (cached by Streamlit to avoid repeated loading)."""
                    tokenizer = AutoTokenizer.from_pretrained("xingyu1996/minimind")
                    model = AutoModelForCausalLM.from_pretrained(
                        "xingyu1996/minimind",
                        trust_remote_code=True,
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        device_map="auto"
                    )
                    return model, tokenizer

                # Load the model and tokenizer
                model, tokenizer = load_model()
# 构建聊天历史 | |
messages = [] | |
for msg in st.session_state.messages: | |
messages.append(msg) | |
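
                # apply_chat_template serializes the message list into the model's chat
                # format; add_generation_prompt=True appends the assistant-turn marker
                # so the model continues as the assistant.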
                # Generate a reply
                prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
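
                # Sampling notes: temperature rescales the logits, and top_p=0.9 restricts
                # sampling to the smallest token set covering 90% cumulative probability.
                # Falling back to eos_token_id for padding (below) is an assumption for
                # tokenizers that define no pad token.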
                # Generate the full reply in one pass (no token streaming)
                output_ids = model.generate(
                    inputs.input_ids,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=0.9,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
                    eos_token_id=tokenizer.eos_token_id
                )
                # Decode only the newly generated tokens, skipping the prompt
                full_response = tokenizer.decode(output_ids[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
                # Show the final response
                message_placeholder.markdown(full_response)
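                # For incremental display, transformers' TextIteratorStreamer could feed
                # message_placeholder token by token; that variant is not implemented here.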
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            full_response = f"Sorry, an error occurred while generating the reply. Error details: {str(e)}"
            message_placeholder.markdown(full_response)

    # Append the assistant reply to the session history
    st.session_state.messages.append({"role": "assistant", "content": full_response})
# 添加重置按钮 | |
if st.button("清空对话"): | |
st.session_state.messages = [] | |
st.experimental_rerun() | |

# Footer
st.markdown("---")
st.markdown("Made with ❤️ by [xingyu1996](https://huggingface.co/xingyu1996)")