# minimind / app.py
# Author: xingyu1996 — "Update app.py" (commit 6016ac8, verified)
import streamlit as st
import os
from huggingface_hub import login, InferenceClient
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
# Page configuration: browser tab title, icon, and a centered layout.
st.set_page_config(
page_title="MiniMind 聊天机器人",
page_icon="🤖",
layout="centered",
)
# In-page title and a one-line description shown above the chat area.
st.title("🤖 MiniMind 聊天机器人")
st.markdown("这是基于MiniMind模型的聊天应用。输入你的问题,AI将为你解答!")
# Sidebar: sampling controls plus a short "about this model" section.
with st.sidebar:
    st.header("模型设置")
    # Sampling temperature — higher means more random output.
    temperature = st.slider(
        "温度",
        min_value=0.1,
        max_value=1.0,
        value=0.7,
        step=0.1,
        help="较高的值使输出更随机,较低的值使其更确定",
    )
    # Upper bound on the number of newly generated tokens.
    max_tokens = st.slider("最大生成长度", min_value=64, max_value=1024, value=512, step=64)
    st.markdown("---")
    st.markdown("## 关于模型")
    st.markdown("""
MiniMind是一个轻量级语言模型,可以进行文本生成、问答和聊天等任务。
[查看模型主页](https://huggingface.co/xingyu1996/minimind)
""")
# Conversation history lives in session_state so it survives Streamlit reruns.
if "messages" not in st.session_state:
    st.session_state.messages = []

# Replay the stored conversation on every rerun.
for msg in st.session_state.messages:
    with st.chat_message(msg["role"]):
        st.markdown(msg["content"])
# The cached loader is defined once at script scope instead of being
# re-declared inside the request handler's try/spinner on every submission.
@st.cache_resource
def load_model():
    """Load and cache the MiniMind model and tokenizer.

    Returns:
        tuple: (model, tokenizer). The model is auto-placed
        (device_map="auto") and uses fp16 on GPU, fp32 on CPU.
    """
    tokenizer = AutoTokenizer.from_pretrained("xingyu1996/minimind")
    model = AutoModelForCausalLM.from_pretrained(
        "xingyu1996/minimind",
        trust_remote_code=True,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto",
    )
    return model, tokenizer


# Main chat handler: runs once per user submission.
if prompt := st.chat_input("输入你的问题..."):
    # Record and echo the user's message.
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    # Assistant turn: generate a reply into a placeholder.
    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        full_response = ""
        try:
            with st.spinner("思考中..."):
                model, tokenizer = load_model()

                # Full conversation (including the new user turn) as the prompt.
                messages = list(st.session_state.messages)
                prompt_text = tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=True
                )
                inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)

                # Blocking generation — the reply appears all at once
                # (no token-by-token streaming).
                output_ids = model.generate(
                    inputs.input_ids,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_p=0.9,
                    do_sample=True,
                    # Some checkpoints define no pad token; fall back to EOS
                    # so generate() does not warn or fail on padding.
                    pad_token_id=(
                        tokenizer.pad_token_id
                        if tokenizer.pad_token_id is not None
                        else tokenizer.eos_token_id
                    ),
                    eos_token_id=tokenizer.eos_token_id,
                )
                # Decode only the newly generated tokens (skip the prompt).
                full_response = tokenizer.decode(
                    output_ids[0][inputs.input_ids.shape[1]:],
                    skip_special_tokens=True,
                )
            message_placeholder.markdown(full_response)
        except Exception as e:
            # Surface the error in the UI and keep the app alive; the error
            # text is also stored as the assistant turn.
            st.error(f"发生错误: {str(e)}")
            full_response = f"抱歉,生成回复时出错。错误信息: {str(e)}"
            message_placeholder.markdown(full_response)

    # Persist the assistant turn so it is replayed on the next rerun.
    st.session_state.messages.append({"role": "assistant", "content": full_response})
# Reset button: clear the conversation history and rerun the script.
if st.button("清空对话"):
    st.session_state.messages = []
    # st.experimental_rerun() was deprecated and then removed in modern
    # Streamlit releases; st.rerun() is the supported replacement.
    st.rerun()
# Footer: attribution line shown below the chat area.
st.markdown("---")
st.markdown("Made with ❤️ by [xingyu1996](https://huggingface.co/xingyu1996)")