import os
import json
import gradio as gr
from datetime import datetime
from threading import Lock
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch

# ========== Auto-create folders ==========
os.makedirs("chat_history", exist_ok=True)
os.makedirs("system", exist_ok=True)

# ========== Load System Context ==========
context_path = "system/context.txt"
if not os.path.exists(context_path):
    raise FileNotFoundError(f"Missing system context file at {context_path}!")

with open(context_path, "r", encoding="utf-8") as f:
    loaded_context = f.read()

# ========== Model Setup ==========
# Serializes generation so concurrent Gradio requests don't interleave on the model.
lock = Lock()

# Provide the folder path, not the file path
model_folder = "model/Mistral-7B-Instruct-v0.3"

# Load the model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_folder, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_folder)

# Set pad_token to eos_token if pad_token is not available
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Initialize the pipeline for text generation
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)


# ========== Helper Functions ==========
def sanitize_username(username):
    """Keep only alphanumerics, '_' and '-' so the name is safe as a folder name."""
    return ''.join(c for c in username if c.isalnum() or c in ('_', '-')).strip()


def user_folder(username):
    """Return the per-user chat-history folder path."""
    return os.path.join("chat_history", username)


def load_latest_history(username):
    """Load the most recent history file for *username* as (speaker, message) pairs.

    NOTE(review): each message is stored on one line, so a multi-line message
    only round-trips its first line — confirm this is acceptable.
    """
    folder = user_folder(username)
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        return []
    files = sorted(os.listdir(folder), reverse=True)
    if not files:
        return []
    latest_file = os.path.join(folder, files[0])
    with open(latest_file, "r", encoding="utf-8") as f:
        lines = f.readlines()
    history = []
    for line in lines:
        if ": " in line:
            speaker, msg = line.split(": ", 1)
            history.append((speaker.strip(), msg.strip()))
    return history


def save_history(username, history):
    """Append the newest exchange to the user's history file."""
    folder = user_folder(username)
    os.makedirs(folder, exist_ok=True)
    filepath = os.path.join(folder, "history.txt")
    with open(filepath, "a", encoding="utf-8") as f:
        # Only write the last two new entries (user + Sanny Lin)
        for speaker, msg in history[-2:]:
            f.write(f"{speaker}: {msg}\n")


def format_chat(history):
    """Render the (speaker, message) history as the HTML shown in the chat pane."""
    formatted = ""
    for speaker, msg in history:
        if speaker == "Sanny Lin":
            formatted += f"""
{msg}
"""
        else:
            formatted += f"""
{msg}
"""
    return formatted


def generate_reply(username, user_message, history):
    """Generate the assistant's reply, then update and persist *history*.

    Returns ``(rendered_html, updated_history)`` so the caller can keep the
    Gradio state in sync.  (The previous version rebound a local copy of the
    list, so the state component never received the new messages, and it
    built a context prompt that it then discarded, generating from the bare
    user message alone.)
    """
    with lock:
        if not user_message.strip():
            # Keep the return type consistent (HTML string, not a raw list).
            return format_chat(history), history

        # Limit the model context to the last 30 messages.
        history = history[-30:]

        # Build a plain-text prompt: system context first (always, so the
        # model keeps its persona), then recent history, then the new message.
        prompt_lines = [loaded_context,
                        f"You are chatting with {username} now. Reply to this message:"]
        for speaker, msg in history:
            prompt_lines.append(f"{speaker}: {msg}")
        prompt_lines.append(f"{username}: {user_message}")
        prompt = "\n".join(prompt_lines)

        generated_output = generator(
            prompt,
            max_new_tokens=512,       # max_length conflicts with max_new_tokens; use only one
            num_return_sequences=1,
            do_sample=True,
            temperature=0.5,
            top_p=0.5,
            top_k=0,
            typical_p=1,
            repetition_penalty=1.0,
            return_full_text=False,   # don't echo the prompt back in the output
        )
        response = generated_output[0]["generated_text"].strip()

        # Smart truncation: cut at 4096 characters without splitting a word.
        max_chars = 4096
        if len(response) > max_chars:
            truncated = response[:max_chars]
            # was rfind(" \n") (space+newline), which almost never matched
            last_space = truncated.rfind(" ")
            response = truncated[:last_space] if last_space != -1 else truncated

        # Record the exchange and persist it.
        history.append((username, user_message))
        history.append(("Sanny Lin", response))
        save_history(username, history)
        return format_chat(history), history


# ========== Gradio Interface ==========
with gr.Blocks(theme=gr.themes.Monochrome(), css="""
@font-face { font-family: "DaemonFont"; src: url('static/daemon.otf') format('opentype'); }
body { background-color: #121212 !important; }
.gradio-container { background-color: #121212 !important; }
textarea { background-color: #1e1e1e !important; color: white; }
input { background-color: #1e1e1e !important; color: white; }
#chat_display { overflow-y: auto; height: calc(100vh - 200px); }
.sanny-message { font-family: "DaemonFont", sans-serif; }
""") as demo:
    chat_display = gr.HTML(value="", elem_id="chat_display", show_label=False)
    with gr.Row():
        username_box = gr.Textbox(label="Username", placeholder="Enter username...",
                                  interactive=True, scale=2)
        user_input = gr.Textbox(placeholder="Type your message...", lines=2,
                                show_label=False, scale=8)
        send_button = gr.Button("Send", scale=1)

    username_state = gr.State("")
    history_state = gr.State([])

    def user_send(user_message, username, history, username_input):
        """Handle a send click: validate the username, then generate a reply.

        Also clears the input box in the same event (the old second .click
        handler raced the main one), and keeps the typed text on a
        validation failure so the user doesn't lose it.
        """
        if not username_input.strip():
            return ("\nPlease enter a valid username first.\n",
                    history, username, user_message)
        username_input = sanitize_username(username_input)
        if not username:
            username = username_input
        history = history or load_latest_history(username)
        rendered, history = generate_reply(username, user_message, history)
        return rendered, history, username, ""

    send_button.click(
        fn=user_send,
        inputs=[user_input, username_state, history_state, username_box],
        outputs=[chat_display, history_state, username_state, user_input],
    )

    demo.load(None, None, None, js="""
    () => {
        const textbox = document.querySelector('textarea');
        const sendButton = document.querySelector('button');
        textbox.addEventListener('keydown', function(e) {
            if (e.key === 'Enter' && !e.shiftKey) {
                e.preventDefault();
                sendButton.click();
            }
        });
    }
    """)

demo.launch(share=False)