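# Spaces: Sleeping
# NOTE(review): page-status text (presumably a Hugging Face Space banner) that
# was copied along with this source; kept as a comment so the file parses.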
import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
# Hub repo id of the fine-tuned chat model; load_model() reads tokenizer,
# weights and generation config from it.
model_name = "abooze/ft-deepseek-llm-7b-chat-dpo-pairs"

# Page header. The title emoji is the speech-balloon character (U+1F4AC); the
# scraped source had it mis-decoded as "π¬".
st.title("💬 RealMind AI")
st.markdown("Chat with RealMind AI!")
@st.cache_resource
def load_model():
    """Load the tokenizer, model and generation config for ``model_name``.

    Streamlit re-executes this whole script on every user interaction, so the
    function is wrapped in ``st.cache_resource``: the weights are loaded once
    per server process and the same objects are reused on every rerun,
    instead of being reloaded for each chat message.

    Returns:
        tuple: ``(tokenizer, model, gen_config)``. ``gen_config`` is the
        repo's ``GenerationConfig`` with ``pad_token_id`` set to its
        ``eos_token_id``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

    # On GPU, half precision halves the weight memory of a 7B model compared
    # with float32 (which often does not fit and gets offloaded). On CPU keep
    # float32, since half-precision CPU kernels are often slow or unsupported.
    if torch.cuda.is_available():
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True,
    )

    gen_config = GenerationConfig.from_pretrained(model_name)
    # Use EOS as the padding id so generation has a pad token defined.
    gen_config.pad_token_id = gen_config.eos_token_id
    return tokenizer, model, gen_config
tokenizer, model, gen_config = load_model()

# NOTE: each prompt is answered independently. The message list below is
# rebuilt from scratch on every turn and earlier turns are not re-rendered,
# so the model never sees prior conversation. (Replaying a system prompt and
# the stored history was sketched here originally but left disabled.)

user_input = st.chat_input("Ask something...")
if user_input:
    # Single-turn conversation: only the current user message is sent.
    st.session_state.messages = [{"role": "user", "content": user_input}]
    st.chat_message("user").write(user_input)

    with st.spinner("Generating response..."):
        input_ids = tokenizer.apply_chat_template(
            st.session_state.messages,
            return_tensors="pt",
            add_generation_prompt=True,
        ).to(model.device)

        outputs = model.generate(
            input_ids=input_ids,
            # One unpadded sequence: every position is a real token. Passing
            # the mask explicitly matters because pad == eos, so generate()
            # cannot infer it from the pad id.
            attention_mask=torch.ones_like(input_ids),
            # Use the config returned by load_model(); the explicit keyword
            # arguments below override its sampling settings.
            generation_config=gen_config,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

        # generate() returns prompt + continuation. Slice off the prompt so
        # only newly generated tokens are decoded. Splitting the decoded text
        # on a role marker such as "<|assistant|>" is unreliable: the marker
        # depends on the chat template and may be stripped by
        # skip_special_tokens, which would echo the prompt back to the user.
        new_tokens = outputs[0, input_ids.shape[-1]:]
        assistant_reply = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    st.chat_message("assistant").write(assistant_reply)
    # Recorded in session state, but reset at the start of the next turn.
    st.session_state.messages.append({"role": "assistant", "content": assistant_reply})