# SannyChatMini / main.py
# Hugging Face Hub page header captured with the source (uploader: taellinglin,
# "Upload 3 files", commit 0be3d69 verified) — kept as a comment so the file
# parses as Python.
import html
import json
import os
from datetime import datetime
from threading import Lock

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# ========== Auto-create folders ==========
# Ensure the on-disk layout exists before anything reads or writes it.
os.makedirs("chat_history", exist_ok=True)
os.makedirs("system", exist_ok=True)

# ========== Load System Context ==========
# System prompt that seeds every new conversation. The app refuses to start
# without it rather than silently running with an empty persona.
context_path = "system/context.txt"
if not os.path.exists(context_path):
    raise FileNotFoundError(f"Missing system context file at {context_path}!")
with open(context_path, "r", encoding="utf-8") as f:
    loaded_context = f.read()

# ========== Simple Chatbot Logic ==========
# Serializes generate_reply() so concurrent Gradio requests don't interleave
# model calls and history-file writes.
lock = Lock()

# Provide the folder path, not the file path
model_folder = "model/Mistral-7B-Instruct-v0.3"

# Load the model and tokenizer (bfloat16 halves memory vs. float32).
model = AutoModelForCausalLM.from_pretrained(model_folder, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_folder)

# Set pad_token to eos_token if pad_token is not available
# (required for batch_encode_plus(padding=True) in generate_reply).
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Initialize the pipeline for text generation
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
# ========== Helper Functions ==========
def sanitize_username(username):
    """Strip *username* down to alphanumerics, underscores and hyphens."""
    kept = [ch for ch in username if ch.isalnum() or ch in "_-"]
    return "".join(kept).strip()
def user_folder(username):
    """Return the directory that holds *username*'s chat transcripts."""
    base = "chat_history"
    return os.path.join(base, username)
def load_latest_history(username):
    """Parse the newest saved transcript for *username*.

    Returns a list of (speaker, message) tuples; lines that don't contain
    the ": " separator are skipped. Creates the user's folder (and returns
    an empty history) when nothing has been saved yet.
    """
    folder = user_folder(username)
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        return []
    files = sorted(os.listdir(folder), reverse=True)
    if not files:
        return []
    latest_file = os.path.join(folder, files[0])
    history = []
    with open(latest_file, "r", encoding="utf-8") as fh:
        for raw in fh.readlines():
            if ": " not in raw:
                continue
            speaker, text = raw.split(": ", 1)
            history.append((speaker.strip(), text.strip()))
    return history
def save_history(username, history):
    """Append the newest exchange to the user's transcript file.

    Only the final two entries of *history* (the new user turn and the new
    Sanny Lin turn) are written, one "speaker: message" line each.
    """
    folder = user_folder(username)
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, "history.txt")
    lines = [f"{speaker}: {text}\n" for speaker, text in history[-2:]]
    with open(filepath, "a", encoding="utf-8") as fh:
        fh.writelines(lines)
def format_chat(history):
    """Render (speaker, message) *history* as chat-bubble HTML.

    Sanny Lin's messages become left-aligned red bubbles (styled with the
    'sanny-message' class); everyone else's become right-aligned blue bubbles.

    Fix: message text is now passed through html.escape() before being
    interpolated into the markup — user input and model output previously
    went into the page verbatim, allowing HTML/script injection into the
    chat view.
    """
    formatted = ""
    for user, msg in history:
        safe_msg = html.escape(msg)  # neutralize <, >, &, quotes from untrusted text
        if user == "Sanny Lin":
            formatted += f"""
<div style='text-align: left; margin: 5px;'>
    <span class='sanny-message' style='background-color: #e74c3c; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
        {safe_msg}
    </span>
</div>
"""
        else:
            formatted += f"""
<div style='text-align: right; margin: 5px;'>
    <span style='background-color: #3498db; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
        {safe_msg}
    </span>
</div>
"""
    return formatted
def generate_reply(username, user_message, history):
    """Generate Sanny Lin's reply to *user_message*, persist it, and return chat HTML.

    Args:
        username: sanitized display name of the human participant.
        user_message: the new message text.
        history: list of (speaker, message) tuples accumulated so far.

    Returns:
        The HTML string from format_chat(history) on success — or, when
        *user_message* is blank, the history list unchanged.
        NOTE(review): that blank-message branch returns a list while every
        other path returns an HTML string; callers must tolerate both.
    """
    with lock:
        if not user_message.strip():
            return history
        # Keep only the 30 most recent turns as conversational context.
        # NOTE(review): this rebinds 'history' to a copy — the caller's list
        # is never mutated, so the caller only sees new turns via the return
        # value / saved file.
        history = history[-30:]  # Limit to the last 30 messages
        messages = []
        # Start with the system context (only when the conversation is brand new).
        if not history:
            messages.append({"role": "system", "content": loaded_context})
        # Add the last 30 messages to the conversation history.
        for user, msg in history:
            role = "user" if user == username else "assistant"
            messages.append({"role": role, "content": msg})
        # Add the user message at the end.
        messages.append({"role": "user", "content": user_message})
        # Append the personalized steering prompt at the end of the context.
        user_prompt = f"You are chatting with {username} now. Reply to this message:"
        messages.append({"role": "system", "content": user_prompt})
        # Extract the content part of each message for encoding.
        text_messages = [message["content"] for message in messages]
        # NOTE(review): 'prompt' is computed but never used — the generator call
        # below receives only user_message, so the system context and prior
        # history assembled above never reach the model. Looks like a bug;
        # confirm intended behavior before relying on the persona/context.
        prompt = tokenizer.batch_encode_plus(text_messages, return_tensors="pt", padding=True, truncation=False)
        # Generate the assistant's reply.
        # NOTE(review): do_sample=True means sampling IS enabled (a previous
        # comment here claimed it was disabled); repetition_penalty=1 is a
        # no-op, and max_length is largely redundant once max_new_tokens is set.
        generated_output = generator(user_message,
        max_length=32768,
        max_new_tokens=512,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.5,
        top_p=0.5,
        top_k=0,
        typical_p=1,
        repetition_penalty=1)
        response = generated_output[0]["generated_text"]
        # Strip the echoed prompt: text-generation pipelines prepend the input
        # text to the generated output.
        if response.startswith(user_message):
            response = response[len(user_message):].strip()
        # Smart truncation: cap at 4096 characters without cutting mid-word.
        max_length = 4096
        if len(response) > max_length:
            # Find the last space before the cutoff point.
            truncated_response = response[:max_length]
            last_space_idx = truncated_response.rfind(" ")
            if last_space_idx != -1:
                response = truncated_response[:last_space_idx]
            else:
                response = truncated_response
        # Record the exchange, persist it, and render the transcript.
        history.append((username, user_message))
        history.append(("Sanny Lin", response))
        save_history(username, history)
        return format_chat(history)
# ========== Gradio Interface ==========
# Dark-themed single-page chat UI. The CSS string is passed to the browser
# verbatim; "DaemonFont" styles only Sanny Lin's bubbles (.sanny-message).
with gr.Blocks(theme=gr.themes.Monochrome(), css="""
@font-face {
font-family: "DaemonFont";
src: url('static/daemon.otf') format('opentype');
}
body { background-color: #121212 !important; }
.gradio-container { background-color: #121212 !important; }
textarea { background-color: #1e1e1e !important; color: white; }
input { background-color: #1e1e1e !important; color: white; }
#chat_display { overflow-y: auto; height: calc(100vh - 200px); }
.sanny-message {
font-family: "DaemonFont", sans-serif;
}
""") as demo:
    # Rendered transcript (HTML produced by format_chat).
    chat_display = gr.HTML(value="", elem_id="chat_display", show_label=False)
    with gr.Row():
        username_box = gr.Textbox(label="Username", placeholder="Enter username...", interactive=True, scale=2)
        user_input = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False, scale=8)
        send_button = gr.Button("Send", scale=1)
    # Per-session state: the locked-in username and the (speaker, message) list.
    username_state = gr.State("")
    history_state = gr.State([])

    def user_send(user_message, username, history, username_input):
        """Send-button handler: validate username, then delegate to generate_reply.

        Returns (chat_html, history, username), matching the outputs wiring
        below. The first non-empty username entered is sanitized and locked
        in for the session; later edits to the textbox are ignored.
        NOTE(review): generate_reply works on a copy of *history*
        (history[-30:]), so the list returned here does not include the new
        exchange — history_state appears to stay stale and the transcript is
        reloaded from disk instead. Verify this is intended.
        """
        if not username_input.strip():
            return "<div style='color: red;'>Please enter a valid username first.</div>", history, username
        username_input = sanitize_username(username_input)
        if not username:
            username = username_input
        # Lazily load the newest saved transcript on the session's first message.
        history = history or load_latest_history(username)
        return generate_reply(username, user_message, history), history, username

    send_button.click(
        fn=user_send,
        inputs=[user_input, username_state, history_state, username_box],
        outputs=[chat_display, history_state, username_state]
    )
    send_button.click(lambda: "", None, user_input)  # Clear input after send
    # Submit-on-Enter (Shift+Enter still inserts a newline).
    # NOTE(review): document.querySelector grabs the FIRST textarea/button on
    # the page — fragile if Gradio renders others; confirm in the browser.
    demo.load(None, None, None, js="""
() => {
const textbox = document.querySelector('textarea');
const sendButton = document.querySelector('button');
textbox.addEventListener('keydown', function(e) {
if (e.key === 'Enter' && !e.shiftKey) {
e.preventDefault();
sendButton.click();
}
});
}
""")
demo.launch(share=False)