Spaces:
Runtime error
Runtime error
import html
import json
import os
from datetime import datetime
from threading import Lock

import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
# ========== Auto-create folders & load system context ==========
# The system prompt lives in a fixed location; both working directories are
# created up front so later file operations can assume they exist.
context_path = "system/context.txt"
os.makedirs("system", exist_ok=True)
os.makedirs("chat_history", exist_ok=True)

if os.path.exists(context_path):
    with open(context_path, "r", encoding="utf-8") as context_file:
        loaded_context = context_file.read()
else:
    # Without the system context the bot has no persona, so fail fast.
    raise FileNotFoundError(f"Missing system context file at {context_path}!")
# ========== Simple Chatbot Logic ==========
# Serialises access to the shared model across concurrent requests.
lock = Lock()

# from_pretrained expects the model *folder*, not a single weights file.
model_folder = "model/Mistral-7B-Instruct-v0.3"

# Tokenizer first; fall back to EOS for padding when no pad token is defined.
tokenizer = AutoTokenizer.from_pretrained(model_folder)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Model weights are loaded in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    model_folder,
    torch_dtype=torch.bfloat16,
)

# Text-generation pipeline shared by every request.
generator = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
)
# ========== Helper Functions ========== | |
def sanitize_username(username):
    """Return *username* reduced to alphanumerics, underscores and hyphens.

    Every other character (spaces, punctuation, path separators, ...) is
    removed, so the result may be an empty string.
    """
    kept = []
    for ch in username:
        if ch.isalnum() or ch == '_' or ch == '-':
            kept.append(ch)
    return ''.join(kept).strip()
def user_folder(username):
    """Return the path of the history directory for *username*.

    The directory sits directly under ``chat_history``; nothing is created
    on disk by this function.
    """
    history_root = "chat_history"
    return os.path.join(history_root, username)
def load_latest_history(username):
    """Load the most recent saved conversation for *username*.

    The user's folder is created (and an empty history returned) when it does
    not exist yet. Otherwise the file whose name sorts last is parsed.

    The history file stores one ``"<speaker>: <message>"`` record per entry,
    but messages are written verbatim and may span several lines. A line
    therefore only starts a new entry when its prefix is one of the speakers
    this application writes (*username* or ``"Sanny Lin"``); every other line
    is folded back into the preceding message instead of being dropped or
    misread as a new speaker.

    Args:
        username: Sanitised user name; also the expected speaker label of the
            user's own messages.

    Returns:
        A list of ``(speaker, message)`` tuples in file order.
    """
    folder = user_folder(username)
    if not os.path.exists(folder):
        os.makedirs(folder, exist_ok=True)
        return []

    files = sorted(os.listdir(folder), reverse=True)
    if not files:
        return []

    known_speakers = {username, "Sanny Lin"}
    entries = []  # each item: [speaker, [message lines]]
    with open(os.path.join(folder, files[0]), "r", encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.rstrip("\n")
            speaker, separator, text = line.partition(": ")
            speaker = speaker.strip()
            # Before any entry exists, accept whatever speaker the line names
            # so files that start with an unexpected label are still read.
            if separator and (speaker in known_speakers or not entries):
                entries.append([speaker, [text]])
            elif entries:
                # Continuation of a multi-line message.
                entries[-1][1].append(line)

    return [(speaker, "\n".join(parts).strip()) for speaker, parts in entries]
def save_history(username, history):
    """Append the two newest history entries to the user's ``history.txt``.

    The user's folder is created when missing. Only the last two entries
    (the user's message and Sanny Lin's reply) are written, one
    ``"<speaker>: <message>"`` record each.
    """
    target_dir = user_folder(username)
    os.makedirs(target_dir, exist_ok=True)

    newest_entries = history[-2:]
    with open(os.path.join(target_dir, "history.txt"), "a", encoding="utf-8") as out:
        out.writelines(f"{user}: {msg}\n" for user, msg in newest_entries)
def format_chat(history):
    """Render the conversation as HTML chat bubbles.

    Entries spoken by ``"Sanny Lin"`` are left-aligned red bubbles carrying
    the ``sanny-message`` class; every other speaker is rendered as a
    right-aligned blue bubble.

    Message text comes from the user and from the model, so it is
    HTML-escaped before being placed in the markup; otherwise a message
    containing tags would be interpreted by the browser.

    Args:
        history: Iterable of ``(speaker, message)`` pairs.

    Returns:
        The concatenated HTML for all messages (``""`` for an empty history).
    """
    bubbles = []
    for user, msg in history:
        safe_msg = html.escape(str(msg))
        if user == "Sanny Lin":
            bubbles.append(f"""
<div style='text-align: left; margin: 5px;'>
<span class='sanny-message' style='background-color: #e74c3c; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
{safe_msg}
</span>
</div>
""")
        else:
            bubbles.append(f"""
<div style='text-align: right; margin: 5px;'>
<span style='background-color: #3498db; color: white; padding: 10px 15px; border-radius: 20px; display: inline-block; max-width: 70%; word-wrap: break-word;'>
{safe_msg}
</span>
</div>
""")
    # Join once instead of growing a string inside the loop.
    return "".join(bubbles)
def _build_chat_messages(username, user_message, history):
    """Convert stored history plus the new message into chat-template turns.

    Consecutive entries with the same role are merged and a leading assistant
    entry is dropped so the result strictly alternates user/assistant.
    NOTE(review): alternation is what Mistral-style chat templates expect;
    confirm against the template shipped with the model folder.

    The system context and the "You are chatting with ..." instruction are
    folded into the final user turn rather than sent as ``system`` messages,
    so the prompt does not depend on the template supporting that role.
    """
    messages = []
    for speaker, text in history:
        role = "user" if speaker == username else "assistant"
        if messages and messages[-1]["role"] == role:
            messages[-1]["content"] += "\n" + text
        elif messages or role == "user":
            messages.append({"role": role, "content": text})

    final_turn = (
        f"{loaded_context}\n\n"
        f"You are chatting with {username} now. Reply to this message:\n"
        f"{user_message}"
    )
    if messages and messages[-1]["role"] == "user":
        messages[-1]["content"] += "\n" + final_turn
    else:
        messages.append({"role": "user", "content": final_turn})
    return messages


def _truncate_at_word(text, limit):
    """Cut *text* to at most *limit* characters, preferring a space boundary."""
    if len(text) <= limit:
        return text
    clipped = text[:limit]
    last_space_idx = clipped.rfind(" ")
    return clipped[:last_space_idx] if last_space_idx != -1 else clipped


def generate_reply(username, user_message, history):
    """Generate Sanny Lin's reply and return the conversation as HTML.

    The last 30 history entries, the system context and the new message are
    rendered through the tokenizer's chat template and passed to the
    text-generation pipeline. The user message and the reply are appended to
    *history* and persisted with ``save_history``.

    *history* is trimmed and extended **in place** so the list held by the
    caller (the Gradio session state) receives the new turns.

    Args:
        username: Speaker label of the user's entries in *history*.
        user_message: Text the user just sent.
        history: Mutable list of ``(speaker, message)`` tuples.

    Returns:
        HTML produced by ``format_chat`` for the (possibly updated) history.
        A blank *user_message* leaves the history untouched.
    """
    if not user_message.strip():
        # Nothing to answer; still return HTML, which is what the caller renders.
        return format_chat(history)

    with lock:
        # Keep only the last 30 messages without rebinding the caller's list.
        del history[:-30]

        messages = _build_chat_messages(username, user_message, history)
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The pipeline tokenises the prompt again; drop a BOS token rendered by
        # the template so it is not duplicated.
        # NOTE(review): assumes the pipeline adds special tokens itself.
        bos_token = tokenizer.bos_token
        if bos_token and prompt.startswith(bos_token):
            prompt = prompt[len(bos_token):]

        generated_output = generator(
            prompt,
            max_new_tokens=512,      # bounds the reply length
            return_full_text=False,  # return only the newly generated text
            num_return_sequences=1,
            do_sample=True,
            temperature=0.5,
            top_p=0.5,
            top_k=0,
            typical_p=1,
            repetition_penalty=1,
        )
        response = generated_output[0]["generated_text"].strip()

        # Cap the reply at 4096 characters without cutting a word in half.
        response = _truncate_at_word(response, 4096)

        history.append((username, user_message))
        history.append(("Sanny Lin", response))
        save_history(username, history)
        return format_chat(history)
# ========== Gradio Interface ==========
with gr.Blocks(theme=gr.themes.Monochrome(), css="""
@font-face {
font-family: "DaemonFont";
src: url('static/daemon.otf') format('opentype');
}
body { background-color: #121212 !important; }
.gradio-container { background-color: #121212 !important; }
textarea { background-color: #1e1e1e !important; color: white; }
input { background-color: #1e1e1e !important; color: white; }
#chat_display { overflow-y: auto; height: calc(100vh - 200px); }
.sanny-message {
font-family: "DaemonFont", sans-serif;
}
""") as demo:
    chat_display = gr.HTML(value="", elem_id="chat_display", show_label=False)
    with gr.Row():
        username_box = gr.Textbox(label="Username", placeholder="Enter username...", interactive=True, scale=2)
        # elem_ids give the Enter-to-send script below unambiguous targets.
        user_input = gr.Textbox(placeholder="Type your message...", lines=2, show_label=False, scale=8, elem_id="user_input")
        send_button = gr.Button("Send", scale=1, elem_id="send_button")
    username_state = gr.State("")
    history_state = gr.State([])

    def user_send(user_message, username, history, username_input):
        """Handle a click on Send.

        Args:
            user_message: Text from the message box.
            username: Username stored in the session state ("" at first).
            history: Conversation stored in the session state.
            username_input: Raw content of the username box.

        Returns:
            ``(chat_html, history, username)`` for the three outputs.
        """
        # Sanitise *before* validating: a name made only of stripped characters
        # would otherwise become "" and point at the chat_history root.
        cleaned_username = sanitize_username(username_input or "")
        if not cleaned_username:
            return "<div style='color: red;'>Please enter a valid username first.</div>", history, username

        if cleaned_username != username:
            # First message of the session, or the user switched accounts:
            # continue from that user's saved conversation.
            username = cleaned_username
            history = load_latest_history(username)
        elif not history:
            history = load_latest_history(username)

        return generate_reply(username, user_message, history), history, username

    send_button.click(
        fn=user_send,
        inputs=[user_input, username_state, history_state, username_box],
        outputs=[chat_display, history_state, username_state]
    )
    send_button.click(lambda: "", None, user_input)  # Clear input after send

    # Enter sends, Shift+Enter inserts a newline. The message box and button
    # are looked up by elem_id because the first <textarea>/<button> in the
    # page is not guaranteed to be the right one.
    demo.load(None, None, None, js="""
() => {
    const textbox = document.querySelector('#user_input textarea');
    const sendEl = document.getElementById('send_button');
    if (!textbox || !sendEl) {
        return;
    }
    const sendButton = sendEl.tagName === 'BUTTON' ? sendEl : sendEl.querySelector('button');
    if (!sendButton) {
        return;
    }
    textbox.addEventListener('keydown', function(e) {
        if (e.key === 'Enter' && !e.shiftKey) {
            e.preventDefault();
            sendButton.click();
        }
    });
}
""")

demo.launch(share=False)